diff --git a/qdrant_client/async_qdrant_client.py b/qdrant_client/async_qdrant_client.py index b9c57ecf..210c1e56 100644 --- a/qdrant_client/async_qdrant_client.py +++ b/qdrant_client/async_qdrant_client.py @@ -12,6 +12,7 @@ import warnings from copy import deepcopy from typing import Any, Awaitable, Callable, Iterable, Mapping, Optional, Sequence, Union +import numpy as np from qdrant_client import grpc as grpc from qdrant_client.async_client_base import AsyncQdrantBase from qdrant_client.common.client_warnings import show_warning_once @@ -95,6 +96,7 @@ def __init__( Union[Callable[[], str], Callable[[], Awaitable[str]]] ] = None, cloud_inference: bool = False, + local_inference_batch_size: Optional[int] = None, check_compatibility: bool = True, **kwargs: Any, ): @@ -142,6 +144,7 @@ def __init__( "Cloud inference is not supported for local Qdrant, consider using FastEmbed or switch to Qdrant Cloud" ) self.cloud_inference = cloud_inference + self.local_inference_batch_size = local_inference_batch_size async def close(self, grpc_grace: Optional[float] = None, **kwargs: Any) -> None: """Closes the connection to Qdrant @@ -395,7 +398,11 @@ async def query_batch_points( requests = self._resolve_query_batch_request(requests) requires_inference = self._inference_inspector.inspect(requests) if requires_inference and (not self.cloud_inference): - requests = [self._embed_models(request) for request in requests] + requests = list( + self._embed_models( + requests, is_query=True, batch_size=self.local_inference_batch_size + ) + ) return await self._client.query_batch_points( collection_name=collection_name, requests=requests, @@ -522,10 +529,35 @@ async def query_points( query = self._resolve_query(query) requires_inference = self._inference_inspector.inspect([query, prefetch]) if requires_inference and (not self.cloud_inference): - query = self._embed_models(query, is_query=True) if query is not None else None - prefetch = ( - self._embed_models(prefetch, is_query=True) if prefetch is not None else None + query = ( + next( + iter( + self._embed_models( + query, is_query=True, batch_size=self.local_inference_batch_size + ) + ) + ) + if query is not None + else None ) + if isinstance(prefetch, list): + prefetch = list( + self._embed_models( + prefetch, is_query=True, batch_size=self.local_inference_batch_size + ) + ) + else: + prefetch = ( + next( + iter( + self._embed_models( + prefetch, is_query=True, batch_size=self.local_inference_batch_size + ) + ) + ) + if prefetch is not None + else None + ) return await self._client.query_points( collection_name=collection_name, query=query, @@ -661,10 +693,31 @@ async def query_points_groups( query = self._resolve_query(query) requires_inference = self._inference_inspector.inspect([query, prefetch]) if requires_inference and (not self.cloud_inference): - query = self._embed_models(query, is_query=True) if query is not None else None - prefetch = ( - self._embed_models(prefetch, is_query=True) if prefetch is not None else None + query = ( + next( + iter( + self._embed_models( + query, is_query=True, batch_size=self.local_inference_batch_size + ) + ) + ) + if query is not None + else None ) + if isinstance(prefetch, list): + prefetch = list( + self._embed_models( + prefetch, is_query=True, batch_size=self.local_inference_batch_size + ) + ) + elif prefetch is not None: + prefetch = next( + iter( + self._embed_models( + prefetch, is_query=True, batch_size=self.local_inference_batch_size + ) + ) + ) return await self._client.query_points_groups( 
collection_name=collection_name, query=query, @@ -1506,10 +1559,20 @@ async def upsert( ) requires_inference = self._inference_inspector.inspect(points) if requires_inference and (not self.cloud_inference): - if isinstance(points, list): - points = [self._embed_models(point, is_query=False) for point in points] + if isinstance(points, types.Batch): + points = next( + iter( + self._embed_models( + points, is_query=False, batch_size=self.local_inference_batch_size + ) + ) + ) else: - points = self._embed_models(points, is_query=False) + points = list( + self._embed_models( + points, is_query=False, batch_size=self.local_inference_batch_size + ) + ) return await self._client.upsert( collection_name=collection_name, points=points, @@ -1560,7 +1623,11 @@ async def update_vectors( assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}" requires_inference = self._inference_inspector.inspect(points) if requires_inference and (not self.cloud_inference): - points = [self._embed_models(point, is_query=False) for point in points] + points = list( + self._embed_models( + points, is_query=False, batch_size=self.local_inference_batch_size + ) + ) return await self._client.update_vectors( collection_name=collection_name, points=points, @@ -2000,9 +2067,11 @@ async def batch_update_points( assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}" requires_inference = self._inference_inspector.inspect(update_operations) if requires_inference and (not self.cloud_inference): - update_operations = [ - self._embed_models(op, is_query=False) for op in update_operations - ] + update_operations = list( + self._embed_models( + update_operations, is_query=False, batch_size=self.local_inference_batch_size + ) + ) return await self._client.batch_update_points( collection_name=collection_name, update_operations=update_operations, @@ -2426,7 +2495,25 @@ def upload_points( This parameter overwrites shard keys written in the records. """ + + def chain(*iterables: Iterable) -> Iterable: + for iterable in iterables: + yield from iterable + assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}" + if not self.cloud_inference: + iter_points = iter(points) + requires_inference = False + try: + point = next(iter_points) + requires_inference = self._inference_inspector.inspect(point) + points = chain(iter([point]), iter_points) + except (StopIteration, StopAsyncIteration): + points = [] + if requires_inference: + points = self._embed_models_strict( + points, parallel=parallel, batch_size=self.local_inference_batch_size + ) return self._client.upload_points( collection_name=collection_name, points=points, @@ -2478,7 +2565,26 @@ def upload_collection( If multiple shard_keys are provided, the update will be written to each of them. Only works for collections with `custom` sharding method. 
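The async upsert/upload paths above now peek at the first point to decide whether local inference is required and, if so, embed the raw fields in batches of `local_inference_batch_size`. A minimal caller-side sketch of the new parameter, assuming a running Qdrant instance and an already-created collection; the collection name and texts are illustrative only, not taken from this diff:

    import asyncio
    from qdrant_client import AsyncQdrantClient, models

    async def main() -> None:
        # batch size controls how many documents are embedded per local inference call
        client = AsyncQdrantClient(url="http://localhost:6333", local_inference_batch_size=8)
        points = [
            models.PointStruct(
                id=idx,
                vector=models.Document(text=text, model="BAAI/bge-small-en"),
                payload={"text": text},
            )
            for idx, text in enumerate(["hello world", "qdrant client", "local inference"])
        ]
        # Documents are embedded locally (FastEmbed) in batches before the upsert is sent
        await client.upsert(collection_name="demo_collection", points=points)

    asyncio.run(main())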
""" + + def chain(*iterables: Iterable) -> Iterable: + for iterable in iterables: + yield from iterable + assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}" + if not self.cloud_inference: + if not isinstance(vectors, dict) and (not isinstance(vectors, np.ndarray)): + requires_inference = False + try: + iter_vectors = iter(vectors) + vector = next(iter_vectors) + requires_inference = self._inference_inspector.inspect(vector) + vectors = chain(iter([vector]), iter_vectors) + except (StopIteration, StopAsyncIteration): + vectors = [] + if requires_inference: + vectors = self._embed_models_strict( + vectors, parallel=parallel, batch_size=self.local_inference_batch_size + ) return self._client.upload_collection( collection_name=collection_name, vectors=vectors, diff --git a/qdrant_client/async_qdrant_fastembed.py b/qdrant_client/async_qdrant_fastembed.py index eedbe7e0..05661500 100644 --- a/qdrant_client/async_qdrant_fastembed.py +++ b/qdrant_client/async_qdrant_fastembed.py @@ -13,85 +13,42 @@ from itertools import tee from typing import Any, Iterable, Optional, Sequence, Union, get_args from copy import deepcopy -from pathlib import Path import numpy as np from pydantic import BaseModel +from qdrant_client import grpc +from qdrant_client.common.client_warnings import show_warning from qdrant_client.async_client_base import AsyncQdrantBase +from qdrant_client.embed.model_embedder import ModelEmbedder +from qdrant_client.http import models from qdrant_client.conversions import common_types as types from qdrant_client.conversions.conversion import GrpcToRest from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES -from qdrant_client.embed.embed_inspector import InspectorEmbed -from qdrant_client.embed.models import NumericVector, NumericVectorStruct from qdrant_client.embed.schema_parser import ModelSchemaParser -from qdrant_client.embed.utils import FieldPath -from qdrant_client.fastembed_common import QueryResponse -from qdrant_client.http import models from qdrant_client.hybrid.fusion import reciprocal_rank_fusion -from qdrant_client import grpc -from qdrant_client.common.client_warnings import show_warning - -try: - from fastembed import ( - SparseTextEmbedding, - TextEmbedding, - LateInteractionTextEmbedding, - ImageEmbedding, - ) - from fastembed.common import OnnxProvider - from PIL import Image as PilImage -except ImportError: - TextEmbedding = None - SparseTextEmbedding = None - OnnxProvider = None - LateInteractionTextEmbedding = None - ImageEmbedding = None - PilImage = None -SUPPORTED_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = ( - { - model["model"]: (model["dim"], models.Distance.COSINE) - for model in TextEmbedding.list_supported_models() - } - if TextEmbedding - else {} -) -SUPPORTED_SPARSE_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = ( - {model["model"]: model for model in SparseTextEmbedding.list_supported_models()} - if SparseTextEmbedding - else {} -) -IDF_EMBEDDING_MODELS: set[str] = ( - { - model_config["model"] - for model_config in SparseTextEmbedding.list_supported_models() - if model_config.get("requires_idf", None) - } - if SparseTextEmbedding - else set() -) -_LATE_INTERACTION_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = ( - {model["model"]: model for model in LateInteractionTextEmbedding.list_supported_models()} - if LateInteractionTextEmbedding - else {} -) -_IMAGE_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = ( - {model["model"]: model for model in 
ImageEmbedding.list_supported_models()} - if ImageEmbedding - else {} +from qdrant_client.fastembed_common import ( + QueryResponse, + TextEmbedding, + LateInteractionTextEmbedding, + ImageEmbedding, + SparseTextEmbedding, + SUPPORTED_EMBEDDING_MODELS, + SUPPORTED_SPARSE_EMBEDDING_MODELS, + _LATE_INTERACTION_EMBEDDING_MODELS, + _IMAGE_EMBEDDING_MODELS, + IDF_EMBEDDING_MODELS, + OnnxProvider, ) class AsyncQdrantFastembedMixin(AsyncQdrantBase): DEFAULT_EMBEDDING_MODEL = "BAAI/bge-small-en" - embedding_models: dict[str, "TextEmbedding"] = {} - sparse_embedding_models: dict[str, "SparseTextEmbedding"] = {} - late_interaction_embedding_models: dict[str, "LateInteractionTextEmbedding"] = {} - image_embedding_models: dict[str, "ImageEmbedding"] = {} + DEFAULT_BATCH_SIZE = 16 _FASTEMBED_INSTALLED: bool def __init__(self, parser: ModelSchemaParser, **kwargs: Any): self._embedding_model_name: Optional[str] = None self._sparse_embedding_model_name: Optional[str] = None - self._embed_inspector = InspectorEmbed(parser=parser) + self._model_embedder = ModelEmbedder(parser=parser, **kwargs) try: from fastembed import SparseTextEmbedding, TextEmbedding @@ -164,6 +121,7 @@ def set_model( cuda=cuda, device_ids=device_ids, lazy_load=lazy_load, + deprecated=True, **kwargs, ) self._embedding_model_name = embedding_model_name @@ -214,6 +172,7 @@ def set_sparse_model( cuda=cuda, device_ids=device_ids, lazy_load=lazy_load, + deprecated=True, **kwargs, ) self._sparse_embedding_model_name = embedding_model_name @@ -229,111 +188,89 @@ def _import_fastembed(cls) -> None: @classmethod def _get_model_params(cls, model_name: str) -> tuple[int, models.Distance]: cls._import_fastembed() - if model_name not in SUPPORTED_EMBEDDING_MODELS: + if model_name in SUPPORTED_EMBEDDING_MODELS: + return SUPPORTED_EMBEDDING_MODELS[model_name] + if model_name in _LATE_INTERACTION_EMBEDDING_MODELS: + return _LATE_INTERACTION_EMBEDDING_MODELS[model_name] + if model_name in _IMAGE_EMBEDDING_MODELS: + return _IMAGE_EMBEDDING_MODELS[model_name] + if model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS: raise ValueError( - f"Unsupported embedding model: {model_name}. Supported models: {SUPPORTED_EMBEDDING_MODELS}" + "Sparse embeddings do not return fixed embedding size and distance type" ) - return SUPPORTED_EMBEDDING_MODELS[model_name] + raise ValueError(f"Unsupported embedding model: {model_name}") - @classmethod def _get_or_init_model( - cls, + self, model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, providers: Optional[Sequence["OnnxProvider"]] = None, + deprecated: bool = False, **kwargs: Any, ) -> "TextEmbedding": - if model_name in cls.embedding_models: - return cls.embedding_models[model_name] - cls._import_fastembed() - if model_name not in SUPPORTED_EMBEDDING_MODELS: - raise ValueError( - f"Unsupported embedding model: {model_name}. 
Supported models: {SUPPORTED_EMBEDDING_MODELS}" - ) - cls.embedding_models[model_name] = TextEmbedding( + self._import_fastembed() + return self._model_embedder.embedder.get_or_init_model( model_name=model_name, cache_dir=cache_dir, threads=threads, providers=providers, + deprecated=deprecated, **kwargs, ) - return cls.embedding_models[model_name] - @classmethod def _get_or_init_sparse_model( - cls, + self, model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, providers: Optional[Sequence["OnnxProvider"]] = None, + deprecated: bool = False, **kwargs: Any, ) -> "SparseTextEmbedding": - if model_name in cls.sparse_embedding_models: - return cls.sparse_embedding_models[model_name] - cls._import_fastembed() - if model_name not in SUPPORTED_SPARSE_EMBEDDING_MODELS: - raise ValueError( - f"Unsupported embedding model: {model_name}. Supported models: {SUPPORTED_SPARSE_EMBEDDING_MODELS}" - ) - cls.sparse_embedding_models[model_name] = SparseTextEmbedding( + self._import_fastembed() + return self._model_embedder.embedder.get_or_init_sparse_model( model_name=model_name, cache_dir=cache_dir, threads=threads, providers=providers, + deprecated=deprecated, **kwargs, ) - return cls.sparse_embedding_models[model_name] - @classmethod def _get_or_init_late_interaction_model( - cls, + self, model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, providers: Optional[Sequence["OnnxProvider"]] = None, **kwargs: Any, ) -> "LateInteractionTextEmbedding": - if model_name in cls.late_interaction_embedding_models: - return cls.late_interaction_embedding_models[model_name] - cls._import_fastembed() - if model_name not in _LATE_INTERACTION_EMBEDDING_MODELS: - raise ValueError( - f"Unsupported embedding model: {model_name}. Supported models: {_LATE_INTERACTION_EMBEDDING_MODELS}" - ) - cls.late_interaction_embedding_models[model_name] = LateInteractionTextEmbedding( + self._import_fastembed() + return self._model_embedder.embedder.get_or_init_late_interaction_model( model_name=model_name, cache_dir=cache_dir, threads=threads, providers=providers, **kwargs, ) - return cls.late_interaction_embedding_models[model_name] - @classmethod def _get_or_init_image_model( - cls, + self, model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, providers: Optional[Sequence["OnnxProvider"]] = None, **kwargs: Any, ) -> "ImageEmbedding": - if model_name in cls.image_embedding_models: - return cls.image_embedding_models[model_name] - cls._import_fastembed() - if model_name not in _IMAGE_EMBEDDING_MODELS: - raise ValueError( - f"Unsupported embedding model: {model_name}. 
Supported models: {_IMAGE_EMBEDDING_MODELS}" - ) - cls.image_embedding_models[model_name] = ImageEmbedding( + self._import_fastembed() + return self._model_embedder.embedder.get_or_init_image_model( model_name=model_name, cache_dir=cache_dir, threads=threads, providers=providers, **kwargs, ) - return cls.image_embedding_models[model_name] def _embed_documents( self, @@ -343,7 +280,7 @@ def _embed_documents( embed_type: str = "default", parallel: Optional[int] = None, ) -> Iterable[tuple[str, list[float]]]: - embedding_model = self._get_or_init_model(model_name=embedding_model_name) + embedding_model = self._get_or_init_model(model_name=embedding_model_name, deprecated=True) (documents_a, documents_b) = tee(documents, 2) if embed_type == "passage": vectors_iter = embedding_model.passage_embed( @@ -369,7 +306,9 @@ def _sparse_embed_documents( batch_size: int = 32, parallel: Optional[int] = None, ) -> Iterable[types.SparseVector]: - sparse_embedding_model = self._get_or_init_sparse_model(model_name=embedding_model_name) + sparse_embedding_model = self._get_or_init_sparse_model( + model_name=embedding_model_name, deprecated=True + ) vectors_iter = sparse_embedding_model.embed( documents, batch_size=batch_size, parallel=parallel ) @@ -484,6 +423,24 @@ def _validate_collection_info(self, collection_info: models.CollectionInfo) -> N modifier == models.Modifier.IDF ), f"{self.sparse_embedding_model_name} requires modifier IDF, current modifier is {modifier}" + def get_embedding_size(self, model_name: Optional[str] = None) -> int: + """ + Get the size of the embeddings produced by the specified model. + + Args: + model_name: optional, the name of the model to get the embedding size for. If None, the default model will be used. + + Returns: + int: the size of the embeddings produced by the model. + """ + model_name = model_name or self.embedding_model_name + if model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS: + raise ValueError( + f"Sparse embeddings do not have a fixed embedding size. Current model: {model_name}" + ) + (embeddings_size, _) = self._get_model_params(model_name=model_name) + return embeddings_size + def get_fastembed_vector_params( self, on_disk: Optional[bool] = None, @@ -648,7 +605,9 @@ async def query( list[types.ScoredPoint]: List of scored points. """ - embedding_model_inst = self._get_or_init_model(model_name=self.embedding_model_name) + embedding_model_inst = self._get_or_init_model( + model_name=self.embedding_model_name, deprecated=True + ) embeddings = list(embedding_model_inst.query_embed(query=query_text)) query_vector = embeddings[0].tolist() if self.sparse_embedding_model_name is None: @@ -665,7 +624,7 @@ async def query( ) ) sparse_embedding_model_inst = self._get_or_init_sparse_model( - model_name=self.sparse_embedding_model_name + model_name=self.sparse_embedding_model_name, deprecated=True ) sparse_vector = list(sparse_embedding_model_inst.query_embed(query=query_text))[0] sparse_query_vector = models.SparseVector( @@ -722,7 +681,9 @@ async def query_batch( list[list[QueryResponse]]: List of lists of responses for each query text. 
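The new `get_embedding_size` helper added above resolves the vector size for dense, late-interaction and image models and refuses sparse models. A short illustration of the intended behaviour, assuming fastembed is installed; the sync client mirrors the async mixin, and "Qdrant/bm25" is only an example of a sparse model name:

    from qdrant_client import QdrantClient

    client = QdrantClient(":memory:")
    # default dense model is "BAAI/bge-small-en"; returns its dimensionality
    dense_dim = client.get_embedding_size()
    # an explicit dense model name works the same way
    dense_dim = client.get_embedding_size("BAAI/bge-small-en")
    # sparse models have no fixed embedding size, so this is expected to raise ValueError
    client.get_embedding_size("Qdrant/bm25")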
""" - embedding_model_inst = self._get_or_init_model(model_name=self.embedding_model_name) + embedding_model_inst = self._get_or_init_model( + model_name=self.embedding_model_name, deprecated=True + ) query_vectors = list(embedding_model_inst.query_embed(query=query_texts)) requests = [] for vector in query_vectors: @@ -740,7 +701,7 @@ async def query_batch( responses = await self.search_batch(collection_name=collection_name, requests=requests) return [self._scored_points_to_query_responses(response) for response in responses] sparse_embedding_model_inst = self._get_or_init_sparse_model( - model_name=self.sparse_embedding_model_name + model_name=self.sparse_embedding_model_name, deprecated=True ) sparse_query_vectors = [ models.SparseVector( @@ -808,7 +769,7 @@ def _resolve_query( GrpcToRest.convert_point_id(query) if isinstance(query, grpc.PointId) else query ) return models.NearestQuery(nearest=query) - if isinstance(query, INFERENCE_OBJECT_TYPES): + if isinstance(query, get_args(INFERENCE_OBJECT_TYPES)): return models.NearestQuery(nearest=query) if query is None: return None @@ -841,166 +802,25 @@ def _resolve_query_batch_request( return [self._resolve_query_request(query) for query in requests] def _embed_models( - self, model: BaseModel, paths: Optional[list[FieldPath]] = None, is_query: bool = False - ) -> Union[BaseModel, NumericVector]: - """Embed model's fields requiring inference - - Args: - model: Qdrant http model containing fields to embed - paths: Path to fields to embed. E.g. [FieldPath(current="recommend", tail=[FieldPath(current="negative", tail=None)])] - is_query: Flag to determine which embed method to use. Defaults to False. - - Returns: - A deepcopy of the method with embedded fields - """ - if paths is None: - if isinstance(model, INFERENCE_OBJECT_TYPES): - return self._embed_raw_data(model, is_query=is_query) - model = deepcopy(model) - paths = self._embed_inspector.inspect(model) - for path in paths: - list_model = [model] if not isinstance(model, list) else model - for item in list_model: - current_model = getattr(item, path.current, None) - if current_model is None: - continue - if path.tail: - self._embed_models(current_model, path.tail, is_query=is_query) - else: - was_list = isinstance(current_model, list) - current_model = ( - [current_model] if not isinstance(current_model, list) else current_model - ) - embeddings = [ - self._embed_raw_data(data, is_query=is_query) for data in current_model - ] - if was_list: - setattr(item, path.current, embeddings) - else: - setattr(item, path.current, embeddings[0]) - return model - - @staticmethod - def _resolve_inference_object(data: models.VectorStruct) -> models.VectorStruct: - """Resolve inference object into a model - - Args: - data: models.VectorStruct - data to resolve, if it's an inference object, convert it to a proper type, - otherwise - keep unchanged - - Returns: - models.VectorStruct: resolved data - """ - if not isinstance(data, models.InferenceObject): - return data - model_name = data.model - value = data.object - options = data.options - if model_name in ( - *SUPPORTED_EMBEDDING_MODELS.keys(), - *SUPPORTED_SPARSE_EMBEDDING_MODELS.keys(), - *_LATE_INTERACTION_EMBEDDING_MODELS.keys(), - ): - return models.Document(model=model_name, text=value, options=options) - if model_name in _IMAGE_EMBEDDING_MODELS: - return models.Image(model=model_name, image=value, options=options) - raise ValueError(f"{model_name} is not among supported models") - - def _embed_raw_data( - self, data: models.VectorStruct, is_query: 
bool = False - ) -> NumericVectorStruct: - """Iterates over the data and calls inference on the fields requiring it - - Args: - data: models.VectorStruct - data to embed, if it's not a field which requires inference, leave it as is - is_query: Flag to determine which embed method to use. Defaults to False. - - Returns: - NumericVectorStruct: Embedded data - """ - data = self._resolve_inference_object(data) - if isinstance(data, models.Document): - return self._embed_document(data, is_query=is_query) - elif isinstance(data, models.Image): - return self._embed_image(data) - elif isinstance(data, dict): - return { - key: self._embed_raw_data(value, is_query=is_query) - for (key, value) in data.items() - } - elif isinstance(data, list): - if data and isinstance(data[0], float): - return data - return [self._embed_raw_data(value, is_query=is_query) for value in data] - return data - - def _embed_document(self, document: models.Document, is_query: bool = False) -> NumericVector: - """Embed a document using the specified embedding model - - Args: - document: Document to embed - is_query: Flag to determine which embed method to use. Defaults to False. - - Returns: - NumericVector: Document's embedding - - Raises: - ValueError: If model is not supported - """ - model_name = document.model - text = document.text - options = document.options or {} - if model_name in SUPPORTED_EMBEDDING_MODELS: - embedding_model_inst = self._get_or_init_model(model_name=model_name, **options) - if not is_query: - embedding = list(embedding_model_inst.embed(documents=[text]))[0].tolist() - else: - embedding = list(embedding_model_inst.query_embed(query=text))[0].tolist() - return embedding - elif model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS: - sparse_embedding_model_inst = self._get_or_init_sparse_model( - model_name=model_name, **options - ) - if not is_query: - sparse_embedding = list(sparse_embedding_model_inst.embed(documents=[text]))[0] - else: - sparse_embedding = list(sparse_embedding_model_inst.query_embed(query=text))[0] - return models.SparseVector( - indices=sparse_embedding.indices.tolist(), values=sparse_embedding.values.tolist() - ) - elif model_name in _LATE_INTERACTION_EMBEDDING_MODELS: - li_embedding_model_inst = self._get_or_init_late_interaction_model( - model_name=model_name, **options - ) - if not is_query: - embedding = list(li_embedding_model_inst.embed(documents=[text]))[0].tolist() - else: - embedding = list(li_embedding_model_inst.query_embed(query=text))[0].tolist() - return embedding - else: - raise ValueError(f"{model_name} is not among supported models") - - def _embed_image(self, image: models.Image) -> NumericVector: - """Embed an image using the specified embedding model - - Args: - image: Image to embed - - Returns: - NumericVector: Image's embedding + self, + raw_models: Union[BaseModel, Iterable[BaseModel]], + is_query: bool = False, + batch_size: Optional[int] = None, + ) -> Iterable[BaseModel]: + yield from self._model_embedder.embed_models( + raw_models=raw_models, + is_query=is_query, + batch_size=batch_size or self.DEFAULT_BATCH_SIZE, + ) - Raises: - ValueError: If model is not supported - """ - model_name = image.model - if model_name in _IMAGE_EMBEDDING_MODELS: - embedding_model_inst = self._get_or_init_image_model( - model_name=model_name, **image.options or {} - ) - if not isinstance(image.image, (str, Path, PilImage.Image)): - raise ValueError( - f"Unsupported image type: {type(image.image)}. 
Image: {image.image}" - ) - embedding = list(embedding_model_inst.embed(images=[image.image]))[0].tolist() - return embedding - raise ValueError(f"{model_name} is not among supported models") + def _embed_models_strict( + self, + raw_models: Iterable[Union[dict[str, BaseModel], BaseModel]], + batch_size: Optional[int] = None, + parallel: Optional[int] = None, + ) -> Iterable[BaseModel]: + yield from self._model_embedder.embed_models_strict( + raw_models=raw_models, + batch_size=batch_size or self.DEFAULT_BATCH_SIZE, + parallel=parallel, + ) diff --git a/qdrant_client/embed/common.py b/qdrant_client/embed/common.py index d3f4e44b..ca33fe78 100644 --- a/qdrant_client/embed/common.py +++ b/qdrant_client/embed/common.py @@ -1,8 +1,6 @@ -from typing import Type +from typing import Type, Union from qdrant_client.http import models INFERENCE_OBJECT_NAMES: set[str] = {"Document", "Image", "InferenceObject"} -INFERENCE_OBJECT_TYPES: tuple[ - Type[models.Document], Type[models.Image], Type[models.InferenceObject] -] = (models.Document, models.Image, models.InferenceObject) +INFERENCE_OBJECT_TYPES = Union[models.Document, models.Image, models.InferenceObject] diff --git a/qdrant_client/embed/embed_inspector.py b/qdrant_client/embed/embed_inspector.py index 386dfc2d..f0394876 100644 --- a/qdrant_client/embed/embed_inspector.py +++ b/qdrant_client/embed/embed_inspector.py @@ -1,5 +1,5 @@ from copy import copy -from typing import Union, Optional, Iterable +from typing import Union, Optional, Iterable, get_args from pydantic import BaseModel @@ -34,6 +34,9 @@ def inspect(self, points: Union[Iterable[BaseModel], BaseModel]) -> list[FieldPa if isinstance(points, BaseModel): self.parser.parse_model(points.__class__) paths.extend(self._inspect_model(points)) + elif isinstance(points, dict): + for value in points.values(): + paths.extend(self.inspect(value)) elif isinstance(points, Iterable): for point in points: if isinstance(point, BaseModel): @@ -113,7 +116,7 @@ def inspect_recursive(member: BaseModel, accumulator: str) -> list[str]: if model is None: return [] - if isinstance(model, INFERENCE_OBJECT_TYPES): + if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)): return [accum] if isinstance(model, BaseModel): @@ -133,7 +136,7 @@ def inspect_recursive(member: BaseModel, accumulator: str) -> list[str]: if not isinstance(current_model, BaseModel): continue - if isinstance(current_model, INFERENCE_OBJECT_TYPES): + if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)): found_paths.append(accum) found_paths.extend(inspect_recursive(current_model, accum)) @@ -158,7 +161,7 @@ def inspect_recursive(member: BaseModel, accumulator: str) -> list[str]: if not isinstance(current_model, BaseModel): continue - if isinstance(current_model, INFERENCE_OBJECT_TYPES): + if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)): found_paths.append(accum) found_paths.extend(inspect_recursive(current_model, accum)) diff --git a/qdrant_client/embed/embedder.py b/qdrant_client/embed/embedder.py new file mode 100644 index 00000000..b739e8f7 --- /dev/null +++ b/qdrant_client/embed/embedder.py @@ -0,0 +1,258 @@ +from collections import defaultdict +from typing import Optional, Sequence, Any, TypeVar, Generic +from pydantic import BaseModel + +from qdrant_client.http import models +from qdrant_client.embed.models import NumericVector +from qdrant_client.fastembed_common import ( + TextEmbedding, + SparseTextEmbedding, + LateInteractionTextEmbedding, + ImageEmbedding, + SUPPORTED_EMBEDDING_MODELS, + 
SUPPORTED_SPARSE_EMBEDDING_MODELS, + _LATE_INTERACTION_EMBEDDING_MODELS, + _IMAGE_EMBEDDING_MODELS, + OnnxProvider, + ImageInput, +) + + +T = TypeVar("T") + + +class ModelInstance(BaseModel, Generic[T], arbitrary_types_allowed=True): # type: ignore[call-arg] + model: T + options: dict[str, Any] + deprecated: bool = False + + +class Embedder: + def __init__(self, threads: Optional[int] = None, **kwargs: Any) -> None: + self.embedding_models: dict[str, list[ModelInstance[TextEmbedding]]] = defaultdict(list) + self.sparse_embedding_models: dict[str, list[ModelInstance[SparseTextEmbedding]]] = ( + defaultdict(list) + ) + self.late_interaction_embedding_models: dict[ + str, list[ModelInstance[LateInteractionTextEmbedding]] + ] = defaultdict(list) + self.image_embedding_models: dict[str, list[ModelInstance[ImageEmbedding]]] = defaultdict( + list + ) + self._threads = threads + + def get_or_init_model( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + providers: Optional[Sequence["OnnxProvider"]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + deprecated: bool = False, + **kwargs: Any, + ) -> TextEmbedding: + if model_name not in SUPPORTED_EMBEDDING_MODELS: + raise ValueError( + f"Unsupported embedding model: {model_name}. Supported models: {SUPPORTED_EMBEDDING_MODELS}" + ) + options = { + "cache_dir": cache_dir, + "threads": threads or self._threads, + "providers": providers, + "cuda": cuda, + "device_ids": device_ids, + **kwargs, + } + for instance in self.embedding_models[model_name]: + if (deprecated and instance.deprecated) or ( + not deprecated and instance.options == options + ): + return instance.model + + model = TextEmbedding(model_name=model_name, **options) + model_instance: ModelInstance[TextEmbedding] = ModelInstance( + model=model, options=options, deprecated=deprecated + ) + self.embedding_models[model_name].append(model_instance) + return model + + def get_or_init_sparse_model( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + providers: Optional[Sequence["OnnxProvider"]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + deprecated: bool = False, + **kwargs: Any, + ) -> SparseTextEmbedding: + if model_name not in SUPPORTED_SPARSE_EMBEDDING_MODELS: + raise ValueError( + f"Unsupported embedding model: {model_name}. Supported models: {SUPPORTED_SPARSE_EMBEDDING_MODELS}" + ) + + options = { + "cache_dir": cache_dir, + "threads": threads or self._threads, + "providers": providers, + "cuda": cuda, + "device_ids": device_ids, + **kwargs, + } + + for instance in self.sparse_embedding_models[model_name]: + if (deprecated and instance.deprecated) or ( + not deprecated and instance.options == options + ): + return instance.model + + model = SparseTextEmbedding(model_name=model_name, **options) + model_instance: ModelInstance[SparseTextEmbedding] = ModelInstance( + model=model, options=options, deprecated=deprecated + ) + self.sparse_embedding_models[model_name].append(model_instance) + return model + + def get_or_init_late_interaction_model( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + providers: Optional[Sequence["OnnxProvider"]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + **kwargs: Any, + ) -> LateInteractionTextEmbedding: + if model_name not in _LATE_INTERACTION_EMBEDDING_MODELS: + raise ValueError( + f"Unsupported embedding model: {model_name}. 
Supported models: {_LATE_INTERACTION_EMBEDDING_MODELS}" + ) + options = { + "cache_dir": cache_dir, + "threads": threads or self._threads, + "providers": providers, + "cuda": cuda, + "device_ids": device_ids, + **kwargs, + } + + for instance in self.late_interaction_embedding_models[model_name]: + if instance.options == options: + return instance.model + + model = LateInteractionTextEmbedding(model_name=model_name, **options) + model_instance: ModelInstance[LateInteractionTextEmbedding] = ModelInstance( + model=model, options=options + ) + self.late_interaction_embedding_models[model_name].append(model_instance) + return model + + def get_or_init_image_model( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + providers: Optional[Sequence["OnnxProvider"]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + **kwargs: Any, + ) -> ImageEmbedding: + if model_name not in _IMAGE_EMBEDDING_MODELS: + raise ValueError( + f"Unsupported embedding model: {model_name}. Supported models: {_IMAGE_EMBEDDING_MODELS}" + ) + options = { + "cache_dir": cache_dir, + "threads": threads or self._threads, + "providers": providers, + "cuda": cuda, + "device_ids": device_ids, + **kwargs, + } + + for instance in self.image_embedding_models[model_name]: + if instance.options == options: + return instance.model + + model = ImageEmbedding(model_name=model_name, **options) + model_instance: ModelInstance[ImageEmbedding] = ModelInstance(model=model, options=options) + self.image_embedding_models[model_name].append(model_instance) + return model + + def embed( + self, + model_name: str, + texts: Optional[list[str]] = None, + images: Optional[list[ImageInput]] = None, + options: Optional[dict[str, Any]] = None, + is_query: bool = False, + batch_size: int = 32, + ) -> NumericVector: + if (texts is None) is (images is None): + raise ValueError("Either documents or images should be provided") + if model_name in SUPPORTED_EMBEDDING_MODELS: + embedding_model_inst = self.get_or_init_model(model_name=model_name, **options or {}) + + if not is_query: + embeddings = [ + embedding.tolist() + for embedding in embedding_model_inst.embed( + documents=texts, batch_size=batch_size + ) + ] + else: + embeddings = [ + embedding.tolist() + for embedding in embedding_model_inst.query_embed(query=texts) + ] + elif model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS.keys(): + embedding_model_inst = self.get_or_init_sparse_model( + model_name=model_name, **options or {} + ) + if not is_query: + embeddings = [ + models.SparseVector( + indices=sparse_embedding.indices.tolist(), + values=sparse_embedding.values.tolist(), + ) + for sparse_embedding in embedding_model_inst.embed( + documents=texts, batch_size=batch_size + ) + ] + else: + embeddings = [ + models.SparseVector( + indices=sparse_embedding.indices.tolist(), + values=sparse_embedding.values.tolist(), + ) + for sparse_embedding in embedding_model_inst.query_embed(query=texts) + ] + + elif model_name in _LATE_INTERACTION_EMBEDDING_MODELS: + embedding_model_inst = self.get_or_init_late_interaction_model( + model_name=model_name, **options or {} + ) + if not is_query: + embeddings = [ + embedding.tolist() + for embedding in embedding_model_inst.embed( + documents=texts, batch_size=batch_size + ) + ] + else: + embeddings = [ + embedding.tolist() + for embedding in embedding_model_inst.query_embed(query=texts) + ] + else: + embedding_model_inst = self.get_or_init_image_model( + model_name=model_name, **options or {} + ) + embeddings 
= [ + embedding.tolist() + for embedding in embedding_model_inst.embed(images=images, batch_size=batch_size) + ] + + return embeddings diff --git a/qdrant_client/embed/model_embedder.py b/qdrant_client/embed/model_embedder.py new file mode 100644 index 00000000..44631067 --- /dev/null +++ b/qdrant_client/embed/model_embedder.py @@ -0,0 +1,389 @@ +import os +from collections import defaultdict +from copy import deepcopy +from multiprocessing import get_all_start_methods +from typing import Optional, Union, Iterable, Any, Type, get_args + +from pydantic import BaseModel + +from qdrant_client.http import models +from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES +from qdrant_client.embed.embed_inspector import InspectorEmbed +from qdrant_client.embed.embedder import Embedder +from qdrant_client.embed.models import NumericVector +from qdrant_client.embed.schema_parser import ModelSchemaParser +from qdrant_client.embed.utils import FieldPath +from qdrant_client.fastembed_common import ( + SUPPORTED_EMBEDDING_MODELS, + SUPPORTED_SPARSE_EMBEDDING_MODELS, + _LATE_INTERACTION_EMBEDDING_MODELS, + _IMAGE_EMBEDDING_MODELS, +) +from qdrant_client.parallel_processor import ParallelWorkerPool, Worker +from qdrant_client.uploader.uploader import iter_batch + + +class ModelEmbedderWorker(Worker): + def __init__(self, **kwargs: Any): + self.model_embedder = ModelEmbedder(**kwargs) + + @classmethod + def start(cls, **kwargs: Any) -> "ModelEmbedderWorker": + return cls(threads=1, **kwargs) + + def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]: + for idx, batch in items: + yield idx, list(self.model_embedder.embed_models_batch(batch)) + + +class ModelEmbedder: + MAX_INTERNAL_BATCH_SIZE = 64 + + def __init__(self, parser: Optional[ModelSchemaParser] = None, **kwargs: Any): + self._batch_accumulator: dict[str, list[INFERENCE_OBJECT_TYPES]] = {} + self._embed_storage: dict[str, list[NumericVector]] = {} + self._embed_inspector = InspectorEmbed(parser=parser) + self.embedder = Embedder(**kwargs) + + def embed_models( + self, + raw_models: Union[BaseModel, Iterable[BaseModel]], + is_query: bool = False, + batch_size: int = 16, + ) -> Iterable[BaseModel]: + """Embed raw data fields in models and return models with vectors + + If any of model fields required inference, a deepcopy of a model with computed embeddings is returned, + otherwise returns original models. + Args: + raw_models: Iterable[BaseModel] - models which can contain fields with raw data + is_query: bool - flag to determine which embed method to use. Defaults to False. + batch_size: int - batch size for inference + Returns: + list[BaseModel]: models with embedded fields + """ + if isinstance(raw_models, BaseModel): + raw_models = [raw_models] + for raw_models_batch in iter_batch(raw_models, batch_size): + yield from self.embed_models_batch(raw_models_batch, is_query) + + def embed_models_strict( + self, + raw_models: Iterable[Union[dict[str, BaseModel], BaseModel]], + batch_size: int = 16, + parallel: Optional[int] = None, + ) -> Iterable[Union[dict[str, BaseModel], BaseModel]]: + """Embed raw data fields in models and return models with vectors + + Requires every input sequences element to contain raw data fields to inference. + Does not accept ready vectors. + + Args: + raw_models: Iterable[BaseModel] - models which contain fields with raw data to inference + batch_size: int - batch size for inference + parallel: int - number of parallel processes to use. Defaults to None. 
+ + Returns: + Iterable[Union[dict[str, BaseModel], BaseModel]]: models with embedded fields + """ + is_small = False + + if isinstance(raw_models, list): + if len(raw_models) < batch_size: + is_small = True + + if parallel is None or parallel == 1 or is_small: + for batch in iter_batch(raw_models, batch_size): + yield from self.embed_models_batch(batch) + else: + raw_models_batches = iter_batch( + raw_models, size=1 + ) # larger batch sizes do not help with data parallel + # on cpu. todo: adjust when multi-gpu is available + if parallel == 0: + parallel = os.cpu_count() + + start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn" + assert parallel is not None # just a mypy complaint + pool = ParallelWorkerPool( + num_workers=parallel, + worker=self._get_worker_class(), + start_method=start_method, + max_internal_batch_size=self.MAX_INTERNAL_BATCH_SIZE, + ) + + for batch in pool.ordered_map(raw_models_batches): + yield from batch + + def embed_models_batch( + self, + raw_models: list[Union[dict[str, BaseModel], BaseModel]], + is_query: bool = False, + ) -> Iterable[BaseModel]: + """Embed a batch of models with raw data fields and return models with vectors + + If any of model fields required inference, a deepcopy of a model with computed embeddings is returned, + otherwise returns original models. + Args: + raw_models: list[Union[dict[str, BaseModel], BaseModel]] - models which can contain fields with raw data + is_query: bool - flag to determine which embed method to use. Defaults to False. + Returns: + Iterable[BaseModel]: models with embedded fields + """ + for raw_model in raw_models: + self._process_model(raw_model, is_query=is_query, accumulating=True) + + if not self._batch_accumulator: + yield from raw_models + else: + yield from ( + self._process_model(raw_model, is_query=is_query, accumulating=False) + for raw_model in raw_models + ) + + def _process_model( + self, + model: Union[dict[str, BaseModel], BaseModel], + paths: Optional[list[FieldPath]] = None, + is_query: bool = False, + accumulating: bool = False, + ) -> Union[dict[str, BaseModel], dict[str, NumericVector], BaseModel, NumericVector]: + """Embed model's fields requiring inference + + Args: + model: Qdrant http model containing fields to embed + paths: Path to fields to embed. E.g. [FieldPath(current="recommend", tail=[FieldPath(current="negative", tail=None)])] + is_query: Flag to determine which embed method to use. Defaults to False. + accumulating: Flag to determine if we are accumulating models for batch embedding. Defaults to False. 
+ + Returns: + A deepcopy of the method with embedded fields + """ + + if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)): + if accumulating: + self._accumulate(model) # type: ignore + else: + return self._drain_accumulator(model, is_query=is_query) # type: ignore + + if paths is None: + model = deepcopy(model) if not accumulating else model + + if isinstance(model, dict): + for key, value in model.items(): + if accumulating: + self._process_model(value, paths, accumulating=True) + else: + model[key] = self._process_model( + value, paths, is_query=is_query, accumulating=False + ) + return model + + paths = paths if paths is not None else self._embed_inspector.inspect(model) + + for path in paths: + list_model = [model] if not isinstance(model, list) else model + for item in list_model: + current_model = getattr(item, path.current, None) + if current_model is None: + continue + if path.tail: + self._process_model( + current_model, path.tail, is_query=is_query, accumulating=accumulating + ) + else: + was_list = isinstance(current_model, list) + current_model = current_model if was_list else [current_model] + + if not accumulating: + embeddings = [ + self._drain_accumulator(data, is_query=is_query) + for data in current_model + ] + if was_list: + setattr(item, path.current, embeddings) + else: + setattr(item, path.current, embeddings[0]) + else: + for data in current_model: + self._accumulate(data) + return model + + def _accumulate(self, data: models.VectorStruct) -> None: + """Add data to batch accumulator + + Args: + data: models.VectorStruct - any vector struct data, if inference object types instances in `data` - add them + to the accumulator, otherwise - do nothing. `InferenceObject` instances are converted to proper types. + + Returns: + None + """ + if isinstance(data, dict): + for value in data.values(): + self._accumulate(value) + return None + + if isinstance(data, list): + for value in data: + if not isinstance(value, get_args(INFERENCE_OBJECT_TYPES)): # if value is a vector + return None + self._accumulate(value) + + if not isinstance(data, get_args(INFERENCE_OBJECT_TYPES)): + return None + + data = self._resolve_inference_object(data) + if data.model not in self._batch_accumulator: + self._batch_accumulator[data.model] = [] + self._batch_accumulator[data.model].append(data) + + def _drain_accumulator(self, data: models.VectorStruct, is_query: bool) -> models.VectorStruct: + """Drain accumulator and replaces inference objects with computed embeddings + It is assumed objects are traversed in the same order as they were added to the accumulator + + Args: + data: models.VectorStruct - any vector struct data, if inference object types instances in `data` - replace + them with computed embeddings. If embeddings haven't yet been computed - compute them and then replace + inference objects. 
+ + Returns: + models.VectorStruct: data with replaced inference objects + """ + if isinstance(data, dict): + for key, value in data.items(): + data[key] = self._drain_accumulator(value, is_query=is_query) + return data + + if isinstance(data, list): + for i, value in enumerate(data): + if not isinstance(value, get_args(INFERENCE_OBJECT_TYPES)): # if value is vector + return data + + data[i] = self._drain_accumulator(value, is_query=is_query) + return data + + if not isinstance(data, get_args(INFERENCE_OBJECT_TYPES)): + return data + + if not self._embed_storage or not self._embed_storage.get(data.model, None): + self._embed_accumulator(is_query=is_query) + + return self._next_embed(data.model) + + def _embed_accumulator(self, is_query: bool = False) -> None: + """Embed all accumulated objects for all models + + Args: + is_query: bool - flag to determine which embed method to use. Defaults to False. + + Returns: + None + """ + + def embed( + objects: list[INFERENCE_OBJECT_TYPES], model_name: str, is_text: bool + ) -> list[NumericVector]: + """Assemble batches by groups, embeds and return embeddings in the original order""" + unique_options: list[dict[str, Any]] = [] + batches: list[Any] = [] + group_indices: dict[int, list[int]] = defaultdict(list) + for i, obj in enumerate(objects): + for j, options in enumerate(unique_options): + if options == obj.options: + group_indices[j].append(i) + batches[j].append(obj.text if is_text else obj.image) + break + else: + # Create a new group if no match is found + group_indices[len(unique_options)] = [i] + unique_options.append(obj.options) + batches.append([obj.text if is_text else obj.image]) + + embeds = [] + for i, options in enumerate(unique_options): + embeds.extend( + [ + embedding + for embedding in self.embedder.embed( + model_name=model_name, + texts=batches[i] if is_text else None, + images=batches[i] if not is_text else None, + is_query=is_query, + options=options or {}, + ) + ] + ) + + iter_embeds = iter(embeds) + ordered_embeddings: list[list[NumericVector]] = [[]] * len(objects) + for indices in group_indices.values(): + for index in indices: + ordered_embeddings[index] = next(iter_embeds) + return ordered_embeddings + + for model in self._batch_accumulator: + if model not in ( + *SUPPORTED_EMBEDDING_MODELS.keys(), + *SUPPORTED_SPARSE_EMBEDDING_MODELS.keys(), + *_LATE_INTERACTION_EMBEDDING_MODELS.keys(), + *_IMAGE_EMBEDDING_MODELS, + ): + raise ValueError(f"{model} is not among supported models") + + for model, data in self._batch_accumulator.items(): + if model in [ + *SUPPORTED_EMBEDDING_MODELS, + *SUPPORTED_SPARSE_EMBEDDING_MODELS, + *_LATE_INTERACTION_EMBEDDING_MODELS, + ]: + embeddings = embed(objects=data, model_name=model, is_text=True) + else: + embeddings = embed(objects=data, model_name=model, is_text=False) + + self._embed_storage[model] = embeddings + self._batch_accumulator.clear() + + def _next_embed(self, model_name: str) -> NumericVector: + """Get next computed embedding from embedded batch + + Args: + model_name: str - retrieve embedding from the storage by this model name + + Returns: + NumericVector: computed embedding + """ + return self._embed_storage[model_name].pop(0) + + @staticmethod + def _resolve_inference_object(data: models.VectorStruct) -> models.VectorStruct: + """Resolve inference object into a model + + Args: + data: models.VectorStruct - data to resolve, if it's an inference object, convert it to a proper type, + otherwise - keep unchanged + + Returns: + models.VectorStruct: resolved data + """ + + if 
not isinstance(data, models.InferenceObject): + return data + + model_name = data.model + value = data.object + options = data.options + if model_name in ( + *SUPPORTED_EMBEDDING_MODELS.keys(), + *SUPPORTED_SPARSE_EMBEDDING_MODELS.keys(), + *_LATE_INTERACTION_EMBEDDING_MODELS.keys(), + ): + return models.Document(model=model_name, text=value, options=options) + if model_name in _IMAGE_EMBEDDING_MODELS: + return models.Image(model=model_name, image=value, options=options) + + raise ValueError(f"{model_name} is not among supported models") + + @classmethod + def _get_worker_class(cls) -> Type[ModelEmbedderWorker]: + return ModelEmbedderWorker diff --git a/qdrant_client/embed/type_inspector.py b/qdrant_client/embed/type_inspector.py index d9c7fbe6..c49c62b5 100644 --- a/qdrant_client/embed/type_inspector.py +++ b/qdrant_client/embed/type_inspector.py @@ -1,8 +1,9 @@ -from typing import Union, Optional, Iterable +from typing import Union, Optional, Iterable, get_args from pydantic import BaseModel from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES + from qdrant_client.embed.schema_parser import ModelSchemaParser from qdrant_client.embed.utils import FieldPath from qdrant_client.http import models @@ -33,6 +34,11 @@ def inspect(self, points: Union[Iterable[BaseModel], BaseModel]) -> bool: self.parser.parse_model(points.__class__) return self._inspect_model(points) + elif isinstance(points, dict): + for value in points.values(): + if self.inspect(value): + return True + elif isinstance(points, Iterable): for point in points: if isinstance(point, BaseModel): @@ -42,7 +48,7 @@ def inspect(self, points: Union[Iterable[BaseModel], BaseModel]) -> bool: return False def _inspect_model(self, model: BaseModel, paths: Optional[list[FieldPath]] = None) -> bool: - if isinstance(model, INFERENCE_OBJECT_TYPES): + if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)): return True paths = ( @@ -81,7 +87,7 @@ def inspect_recursive(member: BaseModel) -> bool: if model is None: return False - if isinstance(model, INFERENCE_OBJECT_TYPES): + if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)): return True if isinstance(model, BaseModel): @@ -99,7 +105,7 @@ def inspect_recursive(member: BaseModel) -> bool: elif isinstance(model, list): for current_model in model: - if isinstance(current_model, INFERENCE_OBJECT_TYPES): + if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)): return True if not isinstance(current_model, BaseModel): @@ -122,7 +128,7 @@ def inspect_recursive(member: BaseModel) -> bool: for key, values in model.items(): values = [values] if not isinstance(values, list) else values for current_model in values: - if isinstance(current_model, INFERENCE_OBJECT_TYPES): + if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)): return True if not isinstance(current_model, BaseModel): diff --git a/qdrant_client/fastembed_common.py b/qdrant_client/fastembed_common.py index fe1b363c..595b907f 100644 --- a/qdrant_client/fastembed_common.py +++ b/qdrant_client/fastembed_common.py @@ -3,6 +3,67 @@ from pydantic import BaseModel, Field from qdrant_client.conversions.common_types import SparseVector +from qdrant_client.http import models + +try: + from fastembed import ( + SparseTextEmbedding, + TextEmbedding, + LateInteractionTextEmbedding, + ImageEmbedding, + ) + from fastembed.common import OnnxProvider, ImageInput +except ImportError: + TextEmbedding = None + SparseTextEmbedding = None + OnnxProvider = None + LateInteractionTextEmbedding = None + ImageEmbedding = None + ImageInput = 
None + + +SUPPORTED_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = ( + { + model["model"]: (model["dim"], models.Distance.COSINE) + for model in TextEmbedding.list_supported_models() + } + if TextEmbedding + else {} +) + +SUPPORTED_SPARSE_EMBEDDING_MODELS: dict[str, dict[str, Any]] = ( + {model["model"]: model for model in SparseTextEmbedding.list_supported_models()} + if SparseTextEmbedding + else {} +) + +IDF_EMBEDDING_MODELS: set[str] = ( + { + model_config["model"] + for model_config in SparseTextEmbedding.list_supported_models() + if model_config.get("requires_idf", None) + } + if SparseTextEmbedding + else set() +) + +_LATE_INTERACTION_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = ( + { + model["model"]: (model["dim"], models.Distance.COSINE) + for model in LateInteractionTextEmbedding.list_supported_models() + } + if LateInteractionTextEmbedding + else {} +) + +_IMAGE_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = ( + { + model["model"]: (model["dim"], models.Distance.COSINE) + for model in ImageEmbedding.list_supported_models() + } + if ImageEmbedding + else {} +) class QueryResponse(BaseModel, extra="forbid"): # type: ignore diff --git a/qdrant_client/parallel_processor.py b/qdrant_client/parallel_processor.py index 57ef7904..ef219488 100644 --- a/qdrant_client/parallel_processor.py +++ b/qdrant_client/parallel_processor.py @@ -1,5 +1,6 @@ import logging import os +from collections import defaultdict from enum import Enum from multiprocessing import Queue, get_context from multiprocessing.context import BaseContext @@ -11,7 +12,7 @@ # Single item should be processed in less than: processing_timeout = 10 * 60 # seconds -max_internal_batch_size = 200 +MAX_INTERNAL_BATCH_SIZE = 200 class QueueSignals(str, Enum): @@ -22,7 +23,7 @@ class QueueSignals(str, Enum): class Worker: @classmethod - def start(cls, **kwargs: Any) -> "Worker": + def start(cls, *args: Any, **kwargs: Any) -> "Worker": raise NotImplementedError() def process(self, items: Iterable[Any]) -> Iterable[Any]: @@ -84,7 +85,13 @@ def input_queue_iterable() -> Iterable[Any]: class ParallelWorkerPool: - def __init__(self, num_workers: int, worker: Type[Worker], start_method: Optional[str] = None): + def __init__( + self, + num_workers: int, + worker: Type[Worker], + start_method: Optional[str] = None, + max_internal_batch_size: int = MAX_INTERNAL_BATCH_SIZE, + ): self.worker_class = worker self.num_workers = num_workers self.input_queue: Optional[Queue] = None @@ -129,6 +136,7 @@ def unordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Ite pushed = 0 read = 0 for item in stream: + self.check_worker_health() if pushed - read < self.queue_size: try: out_item = self.output_queue.get_nowait() @@ -147,7 +155,6 @@ def unordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Ite raise RuntimeError("Thread unexpectedly terminated") yield out_item read += 1 - self.input_queue.put(item) pushed += 1 @@ -174,6 +181,31 @@ def unordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Ite self.input_queue.join_thread() self.output_queue.join_thread() + def semi_ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]: + return self.unordered_map(enumerate(stream), *args, **kwargs) + + def ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]: + buffer = defaultdict(int) + next_expected = 0 + + for idx, item in self.semi_ordered_map(stream, *args, **kwargs): + buffer[idx] = item + while 
next_expected in buffer: + yield buffer.pop(next_expected) + next_expected += 1 + + def check_worker_health(self) -> None: + """ + Checks if any worker process has terminated unexpectedly + """ + for process in self.processes: + if not process.is_alive() and process.exitcode != 0: + self.emergency_shutdown = True + self.join_or_terminate() + raise RuntimeError( + f"Worker PID: {process.pid} terminated unexpectedly with code {process.exitcode}" + ) + def join_or_terminate(self, timeout: Optional[int] = 1) -> None: """ Emergency shutdown diff --git a/qdrant_client/qdrant_client.py b/qdrant_client/qdrant_client.py index 3691052d..eecc4e41 100644 --- a/qdrant_client/qdrant_client.py +++ b/qdrant_client/qdrant_client.py @@ -11,6 +11,8 @@ Union, ) +import numpy as np + from qdrant_client import grpc as grpc from qdrant_client.client_base import QdrantBase from qdrant_client.common.client_warnings import show_warning_once @@ -95,6 +97,7 @@ def __init__( Union[Callable[[], str], Callable[[], Awaitable[str]]] ] = None, cloud_inference: bool = False, + local_inference_batch_size: Optional[int] = None, check_compatibility: bool = True, **kwargs: Any, ): @@ -155,6 +158,7 @@ def __init__( "Cloud inference is not supported for local Qdrant, consider using FastEmbed or switch to Qdrant Cloud" ) self.cloud_inference = cloud_inference + self.local_inference_batch_size = local_inference_batch_size def __del__(self) -> None: self.close() @@ -418,7 +422,11 @@ def query_batch_points( requests = self._resolve_query_batch_request(requests) requires_inference = self._inference_inspector.inspect(requests) if requires_inference and not self.cloud_inference: - requests = [self._embed_models(request) for request in requests] + requests = list( + self._embed_models( + requests, is_query=True, batch_size=self.local_inference_batch_size + ) + ) return self._client.query_batch_points( collection_name=collection_name, @@ -550,10 +558,35 @@ def query_points( query = self._resolve_query(query) requires_inference = self._inference_inspector.inspect([query, prefetch]) if requires_inference and not self.cloud_inference: - query = self._embed_models(query, is_query=True) if query is not None else None - prefetch = ( - self._embed_models(prefetch, is_query=True) if prefetch is not None else None + query = ( + next( + iter( + self._embed_models( + query, is_query=True, batch_size=self.local_inference_batch_size + ) + ) + ) + if query is not None + else None ) + if isinstance(prefetch, list): + prefetch = list( + self._embed_models( + prefetch, is_query=True, batch_size=self.local_inference_batch_size + ) + ) + else: + prefetch = ( + next( + iter( + self._embed_models( + prefetch, is_query=True, batch_size=self.local_inference_batch_size + ) + ) + ) + if prefetch is not None + else None + ) return self._client.query_points( collection_name=collection_name, @@ -694,10 +727,31 @@ def query_points_groups( query = self._resolve_query(query) requires_inference = self._inference_inspector.inspect([query, prefetch]) if requires_inference and not self.cloud_inference: - query = self._embed_models(query, is_query=True) if query is not None else None - prefetch = ( - self._embed_models(prefetch, is_query=True) if prefetch is not None else None + query = ( + next( + iter( + self._embed_models( + query, is_query=True, batch_size=self.local_inference_batch_size + ) + ) + ) + if query is not None + else None ) + if isinstance(prefetch, list): + prefetch = list( + self._embed_models( + prefetch, is_query=True, 
batch_size=self.local_inference_batch_size + ) + ) + elif prefetch is not None: + prefetch = next( + iter( + self._embed_models( + prefetch, is_query=True, batch_size=self.local_inference_batch_size + ) + ) + ) return self._client.query_points_groups( collection_name=collection_name, @@ -1559,10 +1613,20 @@ def upsert( requires_inference = self._inference_inspector.inspect(points) if requires_inference and not self.cloud_inference: - if isinstance(points, list): - points = [self._embed_models(point, is_query=False) for point in points] + if isinstance(points, types.Batch): + points = next( + iter( + self._embed_models( + points, is_query=False, batch_size=self.local_inference_batch_size + ) + ) + ) else: - points = self._embed_models(points, is_query=False) + points = list( + self._embed_models( + points, is_query=False, batch_size=self.local_inference_batch_size + ) + ) return self._client.upsert( collection_name=collection_name, @@ -1615,7 +1679,11 @@ def update_vectors( requires_inference = self._inference_inspector.inspect(points) if requires_inference and not self.cloud_inference: - points = [self._embed_models(point, is_query=False) for point in points] + points = list( + self._embed_models( + points, is_query=False, batch_size=self.local_inference_batch_size + ) + ) return self._client.update_vectors( collection_name=collection_name, @@ -2065,9 +2133,11 @@ def batch_update_points( assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}" requires_inference = self._inference_inspector.inspect(update_operations) if requires_inference and not self.cloud_inference: - update_operations = [ - self._embed_models(op, is_query=False) for op in update_operations - ] + update_operations = list( + self._embed_models( + update_operations, is_query=False, batch_size=self.local_inference_batch_size + ) + ) return self._client.batch_update_points( collection_name=collection_name, @@ -2512,8 +2582,28 @@ def upload_points( This parameter overwrites shard keys written in the records. """ + + def chain(*iterables: Iterable) -> Iterable: + for iterable in iterables: + yield from iterable + assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}" + if not self.cloud_inference: + iter_points = iter(points) + requires_inference = False + try: + point = next(iter_points) + requires_inference = self._inference_inspector.inspect(point) + points = chain(iter([point]), iter_points) + except (StopIteration, StopAsyncIteration): + points = [] + + if requires_inference: + points = self._embed_models_strict( + points, parallel=parallel, batch_size=self.local_inference_batch_size + ) + return self._client.upload_points( collection_name=collection_name, points=points, @@ -2567,8 +2657,29 @@ def upload_collection( If multiple shard_keys are provided, the update will be written to each of them. Only works for collections with `custom` sharding method. 
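
            Example (a minimal sketch, assuming a collection whose named vector "text" was created for
            the default FastEmbed model "BAAI/bge-small-en"; the documents below are embedded locally
            before upload, and the collection name "demo" is purely illustrative):

                client.upload_collection(
                    collection_name="demo",
                    ids=[0, 1],
                    vectors=[
                        {"text": models.Document(text="hello world", model="BAAI/bge-small-en")},
                        {"text": models.Document(text="bye world", model="BAAI/bge-small-en")},
                    ],
                )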
""" + + def chain(*iterables: Iterable) -> Iterable: + for iterable in iterables: + yield from iterable + assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}" + if not self.cloud_inference: + if not isinstance(vectors, dict) and not isinstance(vectors, np.ndarray): + requires_inference = False + try: + iter_vectors = iter(vectors) + vector = next(iter_vectors) + requires_inference = self._inference_inspector.inspect(vector) + vectors = chain(iter([vector]), iter_vectors) + except (StopIteration, StopAsyncIteration): + vectors = [] + + if requires_inference: + vectors = self._embed_models_strict( + vectors, parallel=parallel, batch_size=self.local_inference_batch_size + ) + return self._client.upload_collection( collection_name=collection_name, vectors=vectors, diff --git a/qdrant_client/qdrant_fastembed.py b/qdrant_client/qdrant_fastembed.py index 03c665fd..30a4cb24 100644 --- a/qdrant_client/qdrant_fastembed.py +++ b/qdrant_client/qdrant_fastembed.py @@ -2,95 +2,45 @@ from itertools import tee from typing import Any, Iterable, Optional, Sequence, Union, get_args from copy import deepcopy -from pathlib import Path import numpy as np - from pydantic import BaseModel +from qdrant_client import grpc +from qdrant_client.common.client_warnings import show_warning from qdrant_client.client_base import QdrantBase +from qdrant_client.embed.model_embedder import ModelEmbedder +from qdrant_client.http import models from qdrant_client.conversions import common_types as types from qdrant_client.conversions.conversion import GrpcToRest from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES -from qdrant_client.embed.embed_inspector import InspectorEmbed -from qdrant_client.embed.models import NumericVector, NumericVectorStruct from qdrant_client.embed.schema_parser import ModelSchemaParser -from qdrant_client.embed.utils import FieldPath -from qdrant_client.fastembed_common import QueryResponse -from qdrant_client.http import models from qdrant_client.hybrid.fusion import reciprocal_rank_fusion -from qdrant_client import grpc -from qdrant_client.common.client_warnings import show_warning - -try: - from fastembed import ( - SparseTextEmbedding, - TextEmbedding, - LateInteractionTextEmbedding, - ImageEmbedding, - ) - from fastembed.common import OnnxProvider - from PIL import Image as PilImage -except ImportError: - TextEmbedding = None - SparseTextEmbedding = None - OnnxProvider = None - LateInteractionTextEmbedding = None - ImageEmbedding = None - PilImage = None - - -SUPPORTED_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = ( - { - model["model"]: (model["dim"], models.Distance.COSINE) - for model in TextEmbedding.list_supported_models() - } - if TextEmbedding - else {} -) - -SUPPORTED_SPARSE_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = ( - {model["model"]: model for model in SparseTextEmbedding.list_supported_models()} - if SparseTextEmbedding - else {} -) - -IDF_EMBEDDING_MODELS: set[str] = ( - { - model_config["model"] - for model_config in SparseTextEmbedding.list_supported_models() - if model_config.get("requires_idf", None) - } - if SparseTextEmbedding - else set() -) - -_LATE_INTERACTION_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = ( - {model["model"]: model for model in LateInteractionTextEmbedding.list_supported_models()} - if LateInteractionTextEmbedding - else {} -) - -_IMAGE_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = ( - {model["model"]: model for model in ImageEmbedding.list_supported_models()} - if 
ImageEmbedding - else {} +from qdrant_client.fastembed_common import ( + QueryResponse, + TextEmbedding, + LateInteractionTextEmbedding, + ImageEmbedding, + SparseTextEmbedding, + SUPPORTED_EMBEDDING_MODELS, + SUPPORTED_SPARSE_EMBEDDING_MODELS, + _LATE_INTERACTION_EMBEDDING_MODELS, + _IMAGE_EMBEDDING_MODELS, + IDF_EMBEDDING_MODELS, + OnnxProvider, ) class QdrantFastembedMixin(QdrantBase): DEFAULT_EMBEDDING_MODEL = "BAAI/bge-small-en" - - embedding_models: dict[str, "TextEmbedding"] = {} - sparse_embedding_models: dict[str, "SparseTextEmbedding"] = {} - late_interaction_embedding_models: dict[str, "LateInteractionTextEmbedding"] = {} - image_embedding_models: dict[str, "ImageEmbedding"] = {} + DEFAULT_BATCH_SIZE = 16 _FASTEMBED_INSTALLED: bool def __init__(self, parser: ModelSchemaParser, **kwargs: Any): self._embedding_model_name: Optional[str] = None self._sparse_embedding_model_name: Optional[str] = None - self._embed_inspector = InspectorEmbed(parser=parser) + self._model_embedder = ModelEmbedder(parser=parser, **kwargs) + try: from fastembed import SparseTextEmbedding, TextEmbedding @@ -168,6 +118,7 @@ def set_model( cuda=cuda, device_ids=device_ids, lazy_load=lazy_load, + deprecated=True, **kwargs, ) self._embedding_model_name = embedding_model_name @@ -218,6 +169,7 @@ def set_sparse_model( cuda=cuda, device_ids=device_ids, lazy_load=lazy_load, + deprecated=True, **kwargs, ) self._sparse_embedding_model_name = embedding_model_name @@ -237,124 +189,96 @@ def _import_fastembed(cls) -> None: def _get_model_params(cls, model_name: str) -> tuple[int, models.Distance]: cls._import_fastembed() - if model_name not in SUPPORTED_EMBEDDING_MODELS: + if model_name in SUPPORTED_EMBEDDING_MODELS: + return SUPPORTED_EMBEDDING_MODELS[model_name] + + if model_name in _LATE_INTERACTION_EMBEDDING_MODELS: + return _LATE_INTERACTION_EMBEDDING_MODELS[model_name] + + if model_name in _IMAGE_EMBEDDING_MODELS: + return _IMAGE_EMBEDDING_MODELS[model_name] + + if model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS: raise ValueError( - f"Unsupported embedding model: {model_name}. Supported models: {SUPPORTED_EMBEDDING_MODELS}" + "Sparse embeddings do not return fixed embedding size and distance type" ) - return SUPPORTED_EMBEDDING_MODELS[model_name] + raise ValueError(f"Unsupported embedding model: {model_name}") - @classmethod def _get_or_init_model( - cls, + self, model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, providers: Optional[Sequence["OnnxProvider"]] = None, + deprecated: bool = False, **kwargs: Any, ) -> "TextEmbedding": - if model_name in cls.embedding_models: - return cls.embedding_models[model_name] - - cls._import_fastembed() - - if model_name not in SUPPORTED_EMBEDDING_MODELS: - raise ValueError( - f"Unsupported embedding model: {model_name}. 
Supported models: {SUPPORTED_EMBEDDING_MODELS}" - ) + self._import_fastembed() - cls.embedding_models[model_name] = TextEmbedding( + return self._model_embedder.embedder.get_or_init_model( model_name=model_name, cache_dir=cache_dir, threads=threads, providers=providers, + deprecated=deprecated, **kwargs, ) - return cls.embedding_models[model_name] - @classmethod def _get_or_init_sparse_model( - cls, + self, model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, providers: Optional[Sequence["OnnxProvider"]] = None, + deprecated: bool = False, **kwargs: Any, ) -> "SparseTextEmbedding": - if model_name in cls.sparse_embedding_models: - return cls.sparse_embedding_models[model_name] + self._import_fastembed() - cls._import_fastembed() - - if model_name not in SUPPORTED_SPARSE_EMBEDDING_MODELS: - raise ValueError( - f"Unsupported embedding model: {model_name}. Supported models: {SUPPORTED_SPARSE_EMBEDDING_MODELS}" - ) - - cls.sparse_embedding_models[model_name] = SparseTextEmbedding( + return self._model_embedder.embedder.get_or_init_sparse_model( model_name=model_name, cache_dir=cache_dir, threads=threads, providers=providers, + deprecated=deprecated, **kwargs, ) - return cls.sparse_embedding_models[model_name] - @classmethod def _get_or_init_late_interaction_model( - cls, + self, model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, providers: Optional[Sequence["OnnxProvider"]] = None, **kwargs: Any, ) -> "LateInteractionTextEmbedding": - if model_name in cls.late_interaction_embedding_models: - return cls.late_interaction_embedding_models[model_name] - - cls._import_fastembed() - - if model_name not in _LATE_INTERACTION_EMBEDDING_MODELS: - raise ValueError( - f"Unsupported embedding model: {model_name}. Supported models: {_LATE_INTERACTION_EMBEDDING_MODELS}" - ) - - cls.late_interaction_embedding_models[model_name] = LateInteractionTextEmbedding( + self._import_fastembed() + return self._model_embedder.embedder.get_or_init_late_interaction_model( model_name=model_name, cache_dir=cache_dir, threads=threads, providers=providers, **kwargs, ) - return cls.late_interaction_embedding_models[model_name] - @classmethod def _get_or_init_image_model( - cls, + self, model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, providers: Optional[Sequence["OnnxProvider"]] = None, **kwargs: Any, ) -> "ImageEmbedding": - if model_name in cls.image_embedding_models: - return cls.image_embedding_models[model_name] - - cls._import_fastembed() + self._import_fastembed() - if model_name not in _IMAGE_EMBEDDING_MODELS: - raise ValueError( - f"Unsupported embedding model: {model_name}. 
Supported models: {_IMAGE_EMBEDDING_MODELS}" - ) - - cls.image_embedding_models[model_name] = ImageEmbedding( + return self._model_embedder.embedder.get_or_init_image_model( model_name=model_name, cache_dir=cache_dir, threads=threads, providers=providers, **kwargs, ) - return cls.image_embedding_models[model_name] def _embed_documents( self, @@ -364,7 +288,7 @@ def _embed_documents( embed_type: str = "default", parallel: Optional[int] = None, ) -> Iterable[tuple[str, list[float]]]: - embedding_model = self._get_or_init_model(model_name=embedding_model_name) + embedding_model = self._get_or_init_model(model_name=embedding_model_name, deprecated=True) documents_a, documents_b = tee(documents, 2) if embed_type == "passage": vectors_iter = embedding_model.passage_embed( @@ -391,7 +315,9 @@ def _sparse_embed_documents( batch_size: int = 32, parallel: Optional[int] = None, ) -> Iterable[types.SparseVector]: - sparse_embedding_model = self._get_or_init_sparse_model(model_name=embedding_model_name) + sparse_embedding_model = self._get_or_init_sparse_model( + model_name=embedding_model_name, deprecated=True + ) vectors_iter = sparse_embedding_model.embed( documents, batch_size=batch_size, parallel=parallel @@ -523,6 +449,27 @@ def _validate_collection_info(self, collection_info: models.CollectionInfo) -> N modifier == models.Modifier.IDF ), f"{self.sparse_embedding_model_name} requires modifier IDF, current modifier is {modifier}" + def get_embedding_size( + self, + model_name: Optional[str] = None, + ) -> int: + """ + Get the size of the embeddings produced by the specified model. + + Args: + model_name: optional, the name of the model to get the embedding size for. If None, the default model will be used. + + Returns: + int: the size of the embeddings produced by the model. + """ + model_name = model_name or self.embedding_model_name + if model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS: + raise ValueError( + f"Sparse embeddings do not have a fixed embedding size. Current model: {model_name}" + ) + embeddings_size, _ = self._get_model_params(model_name=model_name) + return embeddings_size + def get_fastembed_vector_params( self, on_disk: Optional[bool] = None, @@ -705,7 +652,9 @@ def query( """ - embedding_model_inst = self._get_or_init_model(model_name=self.embedding_model_name) + embedding_model_inst = self._get_or_init_model( + model_name=self.embedding_model_name, deprecated=True + ) embeddings = list(embedding_model_inst.query_embed(query=query_text)) query_vector = embeddings[0].tolist() @@ -724,7 +673,7 @@ def query( ) sparse_embedding_model_inst = self._get_or_init_sparse_model( - model_name=self.sparse_embedding_model_name + model_name=self.sparse_embedding_model_name, deprecated=True ) sparse_vector = list(sparse_embedding_model_inst.query_embed(query=query_text))[0] sparse_query_vector = models.SparseVector( @@ -788,7 +737,9 @@ def query_batch( list[list[QueryResponse]]: List of lists of responses for each query text. 
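
        Example (a minimal sketch; assumes the collection was created and filled with the same dense
        model configured via set_model(), and that "demo" is an illustrative collection name):

            hits_per_query = client.query_batch(
                collection_name="demo",
                query_texts=["hello world", "bye world"],
            )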
""" - embedding_model_inst = self._get_or_init_model(model_name=self.embedding_model_name) + embedding_model_inst = self._get_or_init_model( + model_name=self.embedding_model_name, deprecated=True + ) query_vectors = list(embedding_model_inst.query_embed(query=query_texts)) requests = [] for vector in query_vectors: @@ -812,7 +763,7 @@ def query_batch( return [self._scored_points_to_query_responses(response) for response in responses] sparse_embedding_model_inst = self._get_or_init_sparse_model( - model_name=self.sparse_embedding_model_name + model_name=self.sparse_embedding_model_name, deprecated=True ) sparse_query_vectors = [ models.SparseVector( @@ -893,7 +844,7 @@ def _resolve_query( ) return models.NearestQuery(nearest=query) - if isinstance(query, INFERENCE_OBJECT_TYPES): + if isinstance(query, get_args(INFERENCE_OBJECT_TYPES)): return models.NearestQuery(nearest=query) if query is None: @@ -929,178 +880,24 @@ def _resolve_query_batch_request( def _embed_models( self, - model: BaseModel, - paths: Optional[list[FieldPath]] = None, + raw_models: Union[BaseModel, Iterable[BaseModel]], is_query: bool = False, - ) -> Union[BaseModel, NumericVector]: - """Embed model's fields requiring inference - - Args: - model: Qdrant http model containing fields to embed - paths: Path to fields to embed. E.g. [FieldPath(current="recommend", tail=[FieldPath(current="negative", tail=None)])] - is_query: Flag to determine which embed method to use. Defaults to False. - - Returns: - A deepcopy of the method with embedded fields - """ - if paths is None: - if isinstance(model, INFERENCE_OBJECT_TYPES): - return self._embed_raw_data(model, is_query=is_query) - model = deepcopy(model) - paths = self._embed_inspector.inspect(model) - for path in paths: - list_model = [model] if not isinstance(model, list) else model - for item in list_model: - current_model = getattr(item, path.current, None) - if current_model is None: - continue - if path.tail: - self._embed_models(current_model, path.tail, is_query=is_query) - else: - was_list = isinstance(current_model, list) - current_model = ( - [current_model] if not isinstance(current_model, list) else current_model - ) - embeddings = [ - self._embed_raw_data(data, is_query=is_query) for data in current_model - ] - if was_list: - setattr(item, path.current, embeddings) - else: - setattr(item, path.current, embeddings[0]) - return model - - @staticmethod - def _resolve_inference_object(data: models.VectorStruct) -> models.VectorStruct: - """Resolve inference object into a model - - Args: - data: models.VectorStruct - data to resolve, if it's an inference object, convert it to a proper type, - otherwise - keep unchanged - - Returns: - models.VectorStruct: resolved data - """ - - if not isinstance(data, models.InferenceObject): - return data - - model_name = data.model - value = data.object - options = data.options - if model_name in ( - *SUPPORTED_EMBEDDING_MODELS.keys(), - *SUPPORTED_SPARSE_EMBEDDING_MODELS.keys(), - *_LATE_INTERACTION_EMBEDDING_MODELS.keys(), - ): - return models.Document(model=model_name, text=value, options=options) - if model_name in _IMAGE_EMBEDDING_MODELS: - return models.Image(model=model_name, image=value, options=options) - - raise ValueError(f"{model_name} is not among supported models") + batch_size: Optional[int] = None, + ) -> Iterable[BaseModel]: + yield from self._model_embedder.embed_models( + raw_models=raw_models, + is_query=is_query, + batch_size=batch_size or self.DEFAULT_BATCH_SIZE, + ) - def _embed_raw_data( + def 
_embed_models_strict( self, - data: models.VectorStruct, - is_query: bool = False, - ) -> NumericVectorStruct: - """Iterates over the data and calls inference on the fields requiring it - - Args: - data: models.VectorStruct - data to embed, if it's not a field which requires inference, leave it as is - is_query: Flag to determine which embed method to use. Defaults to False. - - Returns: - NumericVectorStruct: Embedded data - """ - data = self._resolve_inference_object(data) - - if isinstance(data, models.Document): - return self._embed_document(data, is_query=is_query) - elif isinstance(data, models.Image): - return self._embed_image(data) - elif isinstance(data, dict): - return { - key: self._embed_raw_data(value, is_query=is_query) for key, value in data.items() - } - elif isinstance(data, list): - # we don't want to iterate over a vector - if data and isinstance(data[0], float): - return data - return [self._embed_raw_data(value, is_query=is_query) for value in data] - return data - - def _embed_document(self, document: models.Document, is_query: bool = False) -> NumericVector: - """Embed a document using the specified embedding model - - Args: - document: Document to embed - is_query: Flag to determine which embed method to use. Defaults to False. - - Returns: - NumericVector: Document's embedding - - Raises: - ValueError: If model is not supported - """ - model_name = document.model - text = document.text - options = document.options or {} - if model_name in SUPPORTED_EMBEDDING_MODELS: - embedding_model_inst = self._get_or_init_model(model_name=model_name, **options) - if not is_query: - embedding = list(embedding_model_inst.embed(documents=[text]))[0].tolist() - else: - embedding = list(embedding_model_inst.query_embed(query=text))[0].tolist() - return embedding - elif model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS: - sparse_embedding_model_inst = self._get_or_init_sparse_model( - model_name=model_name, **options - ) - if not is_query: - sparse_embedding = list(sparse_embedding_model_inst.embed(documents=[text]))[0] - else: - sparse_embedding = list(sparse_embedding_model_inst.query_embed(query=text))[0] - - return models.SparseVector( - indices=sparse_embedding.indices.tolist(), values=sparse_embedding.values.tolist() - ) - elif model_name in _LATE_INTERACTION_EMBEDDING_MODELS: - li_embedding_model_inst = self._get_or_init_late_interaction_model( - model_name=model_name, **options - ) - if not is_query: - embedding = list(li_embedding_model_inst.embed(documents=[text]))[0].tolist() - else: - embedding = list(li_embedding_model_inst.query_embed(query=text))[0].tolist() - return embedding - else: - raise ValueError(f"{model_name} is not among supported models") - - def _embed_image(self, image: models.Image) -> NumericVector: - """Embed an image using the specified embedding model - - Args: - image: Image to embed - - Returns: - NumericVector: Image's embedding - - Raises: - ValueError: If model is not supported - """ - model_name = image.model - if model_name in _IMAGE_EMBEDDING_MODELS: - embedding_model_inst = self._get_or_init_image_model( - model_name=model_name, **(image.options or {}) - ) - if not isinstance(image.image, (str, Path, PilImage.Image)): # type: ignore - # PilImage is None if PIL is not installed, - # but we'll fail earlier if it's not installed. - raise ValueError( - f"Unsupported image type: {type(image.image)}. 
Image: {image.image}" - ) - embedding = list(embedding_model_inst.embed(images=[image.image]))[0].tolist() - return embedding - - raise ValueError(f"{model_name} is not among supported models") + raw_models: Iterable[Union[dict[str, BaseModel], BaseModel]], + batch_size: Optional[int] = None, + parallel: Optional[int] = None, + ) -> Iterable[BaseModel]: + yield from self._model_embedder.embed_models_strict( + raw_models=raw_models, + batch_size=batch_size or self.DEFAULT_BATCH_SIZE, + parallel=parallel, + ) diff --git a/qdrant_client/uploader/grpc_uploader.py b/qdrant_client/uploader/grpc_uploader.py index b47676d2..d729b189 100644 --- a/qdrant_client/uploader/grpc_uploader.py +++ b/qdrant_client/uploader/grpc_uploader.py @@ -115,5 +115,5 @@ def process_upload(self, items: Iterable[Any]) -> Generator[bool, None, None]: timeout=self._timeout, ) - def process(self, items: Iterable[Any]) -> Generator[bool, None, None]: + def process(self, items: Iterable[Any]) -> Iterable[bool]: yield from self.process_upload(items) diff --git a/qdrant_client/uploader/rest_uploader.py b/qdrant_client/uploader/rest_uploader.py index ac2868c1..13ad13ae 100644 --- a/qdrant_client/uploader/rest_uploader.py +++ b/qdrant_client/uploader/rest_uploader.py @@ -1,6 +1,5 @@ -import logging from itertools import count -from typing import Any, Generator, Iterable, Optional, Union +from typing import Any, Iterable, Optional, Union from uuid import uuid4 import numpy as np @@ -81,7 +80,7 @@ def start( raise RuntimeError("Collection name could not be empty") return cls(uri=uri, collection_name=collection_name, max_retries=max_retries, **kwargs) - def process(self, items: Iterable[Any]) -> Generator[bool, None, None]: + def process(self, items: Iterable[Any]) -> Iterable[bool]: for batch in items: yield upload_batch( self.openapi_client, diff --git a/tests/embed_tests/test_local_inference.py b/tests/embed_tests/test_local_inference.py index a3d3b2be..b871d660 100644 --- a/tests/embed_tests/test_local_inference.py +++ b/tests/embed_tests/test_local_inference.py @@ -24,14 +24,6 @@ TEST_IMAGE_PATH = Path(__file__).parent / "misc" / "image.jpeg" -# todo: remove once we don't store models in class variables -@pytest.fixture(autouse=True) -def reset_cls_model_storage(): - QdrantClient.embedding_models = {} - QdrantClient.sparse_embedding_models = {} - QdrantClient.late_interaction_embedding_models = {} - - def arg_interceptor(func, kwarg_storage): kwarg_storage.clear() @@ -96,6 +88,8 @@ def test_upsert(prefer_grpc): sparse_doc_2 = models.Document(text="bye world", model=SPARSE_MODEL_NAME) multi_doc_1 = models.Document(text="hello world", model=COLBERT_MODEL_NAME) multi_doc_2 = models.Document(text="bye world", model=COLBERT_MODEL_NAME) + dense_image_1 = models.Image(image=TEST_IMAGE_PATH, model=DENSE_IMAGE_MODEL_NAME) + dense_image_2 = models.Image(image=TEST_IMAGE_PATH, model=DENSE_IMAGE_MODEL_NAME) # region dense unnamed points = [ @@ -133,37 +127,29 @@ def test_upsert(prefer_grpc): # endregion # region named vectors + vectors_config = { + "text": models.VectorParams(size=DENSE_DIM, distance=models.Distance.COSINE), + "multi-text": models.VectorParams( + size=COLBERT_DIM, + distance=models.Distance.COSINE, + multivector_config=models.MultiVectorConfig( + comparator=models.MultiVectorComparator.MAX_SIM + ), + ), + "image": models.VectorParams(size=DENSE_IMAGE_DIM, distance=models.Distance.COSINE), + } + sparse_vectors_config = { + "sparse-text": models.SparseVectorParams(modifier=models.Modifier.IDF) + } 
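    # Sketch of a possible extra check: the dimensions hard-coded in vectors_config above could also be
    # read from the model registry through the get_embedding_size() helper introduced by this change, e.g.
    #   assert local_client.get_embedding_size(DENSE_MODEL_NAME) == DENSE_DIM
    #   assert local_client.get_embedding_size(DENSE_IMAGE_MODEL_NAME) == DENSE_IMAGE_DIM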
local_client.create_collection( COLLECTION_NAME, - vectors_config={ - "text": models.VectorParams(size=DENSE_DIM, distance=models.Distance.COSINE), - "multi-text": models.VectorParams( - size=COLBERT_DIM, - distance=models.Distance.COSINE, - multivector_config=models.MultiVectorConfig( - comparator=models.MultiVectorComparator.MAX_SIM - ), - ), - }, - sparse_vectors_config={ - "sparse-text": models.SparseVectorParams(modifier=models.Modifier.IDF) - }, + vectors_config=vectors_config, + sparse_vectors_config=sparse_vectors_config, ) remote_client.create_collection( COLLECTION_NAME, - vectors_config={ - "text": models.VectorParams(size=DENSE_DIM, distance=models.Distance.COSINE), - "multi-text": models.VectorParams( - size=COLBERT_DIM, - distance=models.Distance.COSINE, - multivector_config=models.MultiVectorConfig( - comparator=models.MultiVectorComparator.MAX_SIM - ), - ), - }, - sparse_vectors_config={ - "sparse-text": models.SparseVectorParams(modifier=models.Modifier.IDF) - }, + vectors_config=vectors_config, + sparse_vectors_config=sparse_vectors_config, ) points = [ models.PointStruct( @@ -172,6 +158,7 @@ def test_upsert(prefer_grpc): "text": dense_doc_1, "multi-text": multi_doc_1, "sparse-text": sparse_doc_1, + "image": dense_image_1, }, ), models.PointStruct( @@ -180,6 +167,7 @@ def test_upsert(prefer_grpc): "text": dense_doc_2, "multi-text": multi_doc_2, "sparse-text": sparse_doc_2, + "image": dense_image_2, }, ), ] @@ -192,6 +180,7 @@ def test_upsert(prefer_grpc): assert isinstance(vec_point.vector["text"], list) assert isinstance(vec_point.vector["multi-text"], list) assert isinstance(vec_point.vector["sparse-text"], models.SparseVector) + assert isinstance(vec_point.vector["image"], list) compare_collections( local_client, remote_client, num_vectors=10, collection_name=COLLECTION_NAME @@ -203,6 +192,7 @@ def test_upsert(prefer_grpc): "text": [dense_doc_1, dense_doc_2], "multi-text": [multi_doc_1, multi_doc_2], "sparse-text": [sparse_doc_1, sparse_doc_2], + "image": [dense_image_1, dense_image_2], }, ) local_client.upsert(COLLECTION_NAME, batch) @@ -213,6 +203,7 @@ def test_upsert(prefer_grpc): assert isinstance(vectors, dict) assert all([isinstance(vector, list) for vector in vectors["text"]]) assert all([isinstance(vector, list) for vector in vectors["multi-text"]]) + assert all([isinstance(vector, list) for vector in vectors["image"]]) assert all([isinstance(vector, models.SparseVector) for vector in vectors["sparse-text"]]) local_client.delete_collection(COLLECTION_NAME) @@ -220,36 +211,63 @@ def test_upsert(prefer_grpc): # endregion -@pytest.mark.parametrize("prefer_grpc", [True, False]) -def test_batch_update_points(prefer_grpc): +@pytest.mark.parametrize("prefer_grpc", [False]) +def test_upload(prefer_grpc): + def recreate_collection(client, collection_name): + if client.collection_exists(collection_name): + client.delete_collection(collection_name) + vector_params = { + "text": models.VectorParams(size=DENSE_DIM, distance=models.Distance.COSINE), + "image": models.VectorParams(size=DENSE_IMAGE_DIM, distance=models.Distance.COSINE), + } + client.create_collection( + collection_name, + vectors_config=vector_params, + sparse_vectors_config={ + "sparse-text": models.SparseVectorParams(modifier=models.Modifier.IDF) + }, + ) + local_client = QdrantClient(":memory:") if not local_client._FASTEMBED_INSTALLED: pytest.skip("FastEmbed is not installed, skipping") remote_client = QdrantClient(prefer_grpc=prefer_grpc) - local_kwargs = {} - local_client._client.batch_update_points = 
arg_interceptor( - local_client._client.batch_update_points, local_kwargs - ) dense_doc_1 = models.Document(text="hello world", model=DENSE_MODEL_NAME) dense_doc_2 = models.Document(text="bye world", model=DENSE_MODEL_NAME) + dense_doc_3 = models.Document(text="world world", model=DENSE_MODEL_NAME) + + sparse_doc_1 = models.Document(text="hello world", model=SPARSE_MODEL_NAME) + sparse_doc_2 = models.Document(text="bye world", model=SPARSE_MODEL_NAME) + sparse_doc_3 = models.Document(text="world world", model=SPARSE_MODEL_NAME) + + dense_image_1 = models.Image(image=TEST_IMAGE_PATH, model=DENSE_IMAGE_MODEL_NAME) + dense_image_2 = models.Image(image=TEST_IMAGE_PATH, model=DENSE_IMAGE_MODEL_NAME) + dense_image_3 = models.Image(image=TEST_IMAGE_PATH, model=DENSE_IMAGE_MODEL_NAME) - # region unnamed points = [ - models.PointStruct(id=1, vector=dense_doc_1), - models.PointStruct(id=2, vector=dense_doc_2), + models.PointStruct( + id=1, vector={"text": dense_doc_1, "image": dense_image_1, "sparse-text": sparse_doc_1} + ), + models.PointStruct( + id=2, vector={"text": dense_doc_2, "image": dense_image_2, "sparse-text": sparse_doc_2} + ), + models.PointStruct( + id=3, vector={"text": dense_doc_3, "image": dense_image_3, "sparse-text": sparse_doc_3} + ), ] - populate_dense_collection(local_client, points) - populate_dense_collection(remote_client, points) + recreate_collection(local_client, COLLECTION_NAME) + recreate_collection(remote_client, COLLECTION_NAME) - batch = models.Batch(ids=[2, 3], vectors=[dense_doc_1, dense_doc_2]) - upsert_operation = models.UpsertOperation(upsert=models.PointsBatch(batch=batch)) - local_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) - remote_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) - current_operation = local_kwargs["update_operations"][0] - current_batch = current_operation.upsert.batch - assert all([isinstance(vector, list) for vector in current_batch.vectors]) + local_client.upload_points(COLLECTION_NAME, points) + remote_client.upload_points(COLLECTION_NAME, points, wait=True) + + assert local_client.count(COLLECTION_NAME).count == len(points) + assert isinstance( + local_client.retrieve(COLLECTION_NAME, ids=[1], with_vectors=True)[0].vector["text"], list + ) # assert doc + # has been substituted with its embedding compare_collections( local_client, @@ -258,60 +276,42 @@ def test_batch_update_points(prefer_grpc): collection_name=COLLECTION_NAME, ) - new_points = [ - models.PointStruct(id=3, vector=dense_doc_1), - models.PointStruct(id=4, vector=dense_doc_2), + recreate_collection(local_client, COLLECTION_NAME) + recreate_collection(remote_client, COLLECTION_NAME) + + vectors = [ + {"text": dense_doc_1, "image": dense_image_1, "sparse-text": sparse_doc_1}, + {"text": dense_doc_2, "image": dense_image_2, "sparse-text": sparse_doc_2}, + {"text": dense_doc_3, "image": dense_image_3, "sparse-text": sparse_doc_3}, ] - upsert_operation = models.UpsertOperation(upsert=models.PointsList(points=new_points)) - local_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) - remote_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) - current_operation = local_kwargs["update_operations"][0] - current_batch = current_operation.upsert.points - assert all([isinstance(vector.vector, list) for vector in current_batch]) + ids = list(range(len(vectors))) + local_client.upload_collection(COLLECTION_NAME, ids=ids, vectors=vectors) + remote_client.upload_collection(COLLECTION_NAME, ids=ids, vectors=vectors, wait=True) - 
update_vectors_operation = models.UpdateVectorsOperation( - update_vectors=models.UpdateVectors(points=[models.PointVectors(id=1, vector=dense_doc_2)]) - ) - upsert_operation = models.UpsertOperation( - upsert=models.PointsList(points=[models.PointStruct(id=5, vector=dense_doc_2)]) - ) - local_client.batch_update_points(COLLECTION_NAME, [update_vectors_operation, upsert_operation]) - remote_client.batch_update_points( - COLLECTION_NAME, [update_vectors_operation, upsert_operation] - ) - current_update_operation = local_kwargs["update_operations"][0] - current_upsert_operation = local_kwargs["update_operations"][1] + assert local_client.count(COLLECTION_NAME).count == len(vectors) + assert isinstance( + local_client.retrieve(COLLECTION_NAME, ids=[1], with_vectors=True)[0].vector["text"], list + ) # assert doc + # has been substituted with its embedding - assert all( - [ - isinstance(vector.vector, list) - for vector in current_update_operation.update_vectors.points - ] - ) - assert all( - [isinstance(vector.vector, list) for vector in current_upsert_operation.upsert.points] + compare_collections( + local_client, + remote_client, + num_vectors=10, + collection_name=COLLECTION_NAME, ) - local_client.delete_collection(COLLECTION_NAME) - remote_client.delete_collection(COLLECTION_NAME) - # endregion - - # region named - points = [ - models.PointStruct(id=1, vector={"text": dense_doc_1}), - models.PointStruct(id=2, vector={"text": dense_doc_2}), - ] + recreate_collection(local_client, COLLECTION_NAME) + recreate_collection(remote_client, COLLECTION_NAME) - populate_dense_collection(local_client, points, vector_name="text") - populate_dense_collection(remote_client, points, vector_name="text") + local_client.upload_points(COLLECTION_NAME, points, parallel=2, batch_size=2) + remote_client.upload_points(COLLECTION_NAME, points, parallel=2, batch_size=2, wait=True) - batch = models.Batch(ids=[2, 3], vectors={"text": [dense_doc_1, dense_doc_2]}) - upsert_operation = models.UpsertOperation(upsert=models.PointsBatch(batch=batch)) - local_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) - remote_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) - current_operation = local_kwargs["update_operations"][0] - current_batch = current_operation.upsert.batch - assert all([isinstance(vector, list) for vector in current_batch.vectors.values()]) + assert local_client.count(COLLECTION_NAME).count == len(points) + assert isinstance( + local_client.retrieve(COLLECTION_NAME, ids=[1], with_vectors=True)[0].vector["text"], list + ) # assert doc + # has been substituted with its embedding compare_collections( local_client, @@ -320,85 +320,42 @@ def test_batch_update_points(prefer_grpc): collection_name=COLLECTION_NAME, ) - new_points = [ - models.PointStruct(id=3, vector={"text": dense_doc_1}), - models.PointStruct(id=4, vector={"text": dense_doc_2}), - ] - upsert_operation = models.UpsertOperation(upsert=models.PointsList(points=new_points)) - local_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) - remote_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) - current_operation = local_kwargs["update_operations"][0] - current_batch = current_operation.upsert.points - assert all([isinstance(vector.vector["text"], list) for vector in current_batch]) - - update_vectors_operation = models.UpdateVectorsOperation( - update_vectors=models.UpdateVectors( - points=[models.PointVectors(id=1, vector={"text": dense_doc_2})] - ) - ) - upsert_operation = models.UpsertOperation( - 
upsert=models.PointsList(points=[models.PointStruct(id=5, vector={"text": dense_doc_2})]) - ) - local_client.batch_update_points(COLLECTION_NAME, [update_vectors_operation, upsert_operation]) - remote_client.batch_update_points( - COLLECTION_NAME, [update_vectors_operation, upsert_operation] - ) - current_update_operation = local_kwargs["update_operations"][0] - current_upsert_operation = local_kwargs["update_operations"][1] + recreate_collection(local_client, COLLECTION_NAME) + recreate_collection(remote_client, COLLECTION_NAME) - assert all( - [ - isinstance(vector.vector["text"], list) - for vector in current_update_operation.update_vectors.points - ] + local_client.upload_collection( + COLLECTION_NAME, ids=ids, vectors=vectors, parallel=2, batch_size=2 ) - assert all( - [ - isinstance(vector.vector["text"], list) - for vector in current_upsert_operation.upsert.points - ] + remote_client.upload_collection( + COLLECTION_NAME, ids=ids, vectors=vectors, parallel=2, batch_size=2, wait=True ) - local_client.delete_collection(COLLECTION_NAME) - remote_client.delete_collection(COLLECTION_NAME) - # endregion - + assert local_client.count(COLLECTION_NAME).count == len(vectors) + assert isinstance( + local_client.retrieve(COLLECTION_NAME, ids=[1], with_vectors=True)[0].vector["text"], list + ) # assert doc + # has been substituted with its embedding -@pytest.mark.parametrize("prefer_grpc", [True, False]) -def test_update_vectors(prefer_grpc): - local_client = QdrantClient(":memory:") - if not local_client._FASTEMBED_INSTALLED: - pytest.skip("FastEmbed is not installed, skipping") - remote_client = QdrantClient(prefer_grpc=prefer_grpc) - local_kwargs = {} - local_client._client.update_vectors = arg_interceptor( - local_client._client.update_vectors, local_kwargs + compare_collections( + local_client, + remote_client, + num_vectors=10, + collection_name=COLLECTION_NAME, ) - dense_doc_1 = models.Document( - text="hello world", - model=DENSE_MODEL_NAME, - ) - dense_doc_2 = models.Document(text="bye world", model=DENSE_MODEL_NAME) - dense_doc_3 = models.Document(text="goodbye world", model=DENSE_MODEL_NAME) - # region unnamed - points = [ - models.PointStruct(id=1, vector=dense_doc_1), - models.PointStruct(id=2, vector=dense_doc_2), - ] + assert isinstance(points[0].vector["text"], models.Document) - populate_dense_collection(local_client, points) - populate_dense_collection(remote_client, points) + recreate_collection(local_client, COLLECTION_NAME) + recreate_collection(remote_client, COLLECTION_NAME) - point_vectors = [ - models.PointVectors(id=1, vector=dense_doc_2), - models.PointVectors(id=2, vector=dense_doc_3), - ] + local_client.upload_points(COLLECTION_NAME, iter(points), parallel=2, batch_size=2) + remote_client.upload_points(COLLECTION_NAME, iter(points), parallel=2, batch_size=2, wait=True) - local_client.update_vectors(COLLECTION_NAME, point_vectors) - remote_client.update_vectors(COLLECTION_NAME, point_vectors) - current_vectors = local_kwargs["points"] - assert all([isinstance(vector.vector, list) for vector in current_vectors]) + assert local_client.count(COLLECTION_NAME).count == len(points) + assert isinstance( + local_client.retrieve(COLLECTION_NAME, ids=[1], with_vectors=True)[0].vector["text"], list + ) # assert doc + # has been substituted with its embedding compare_collections( local_client, @@ -407,28 +364,23 @@ def test_update_vectors(prefer_grpc): collection_name=COLLECTION_NAME, ) - local_client.delete_collection(COLLECTION_NAME) - 
remote_client.delete_collection(COLLECTION_NAME) - # endregion - - # region named - points = [ - models.PointStruct(id=1, vector={"text": dense_doc_1}), - models.PointStruct(id=2, vector={"text": dense_doc_2}), - ] + assert isinstance(vectors[0]["text"], models.Document) - populate_dense_collection(local_client, points, vector_name="text") - populate_dense_collection(remote_client, points, vector_name="text") + recreate_collection(local_client, COLLECTION_NAME) + recreate_collection(remote_client, COLLECTION_NAME) - point_vectors = [ - models.PointVectors(id=1, vector={"text": dense_doc_2}), - models.PointVectors(id=2, vector={"text": dense_doc_3}), - ] + local_client.upload_collection( + COLLECTION_NAME, ids=ids, vectors=iter(vectors), parallel=2, batch_size=2 + ) + remote_client.upload_collection( + COLLECTION_NAME, ids=ids, vectors=iter(vectors), parallel=2, batch_size=2, wait=True + ) - local_client.update_vectors(COLLECTION_NAME, point_vectors) - remote_client.update_vectors(COLLECTION_NAME, point_vectors) - current_vectors = local_kwargs["points"] - assert all([isinstance(vector.vector["text"], list) for vector in current_vectors]) + assert local_client.count(COLLECTION_NAME).count == len(vectors) + assert isinstance( + local_client.retrieve(COLLECTION_NAME, ids=[1], with_vectors=True)[0].vector["text"], list + ) # assert doc + # has been substituted with its embedding compare_collections( local_client, @@ -437,13 +389,9 @@ def test_update_vectors(prefer_grpc): collection_name=COLLECTION_NAME, ) - local_client.delete_collection(COLLECTION_NAME) - remote_client.delete_collection(COLLECTION_NAME) - # endregion - @pytest.mark.parametrize("prefer_grpc", [True, False]) -def test_query_points_and_query_points_groups(prefer_grpc): +def test_query_points(prefer_grpc): local_client = QdrantClient(":memory:") if not local_client._FASTEMBED_INSTALLED: pytest.skip("FastEmbed is not installed, skipping") @@ -452,100 +400,48 @@ def test_query_points_and_query_points_groups(prefer_grpc): local_client._client.query_points = arg_interceptor( local_client._client.query_points, local_kwargs ) - local_client._client.query_points_groups = arg_interceptor( - local_client._client.query_points_groups, local_kwargs - ) - sparse_doc_1 = models.Document(text="hello world", model=SPARSE_MODEL_NAME) sparse_doc_2 = models.Document(text="bye world", model=SPARSE_MODEL_NAME) sparse_doc_3 = models.Document(text="goodbye world", model=SPARSE_MODEL_NAME) sparse_doc_4 = models.Document(text="good afternoon world", model=SPARSE_MODEL_NAME) sparse_doc_5 = models.Document(text="good morning world", model=SPARSE_MODEL_NAME) - points = [ - models.PointStruct(id=i, vector={"text": doc}, payload={"content": doc.text}) + models.PointStruct(id=i, vector={"sparse-text": doc}, payload={"content": doc.text}) for i, doc in enumerate( - [sparse_doc_1, sparse_doc_2, sparse_doc_3, sparse_doc_4, sparse_doc_5] + [sparse_doc_1, sparse_doc_2, sparse_doc_3, sparse_doc_4, sparse_doc_5], ) ] - populate_sparse_collection(local_client, points, vector_name="text") - populate_sparse_collection(remote_client, points, vector_name="text") + populate_sparse_collection(local_client, points, vector_name="sparse-text") + populate_sparse_collection(remote_client, points, vector_name="sparse-text") + # region non-prefetch queries + local_client.query_points(COLLECTION_NAME, sparse_doc_1, using="sparse-text") + remote_client.query_points(COLLECTION_NAME, sparse_doc_1, using="sparse-text") + current_query = local_kwargs["query"] + assert 
isinstance(current_query.nearest, models.SparseVector) + # retrieved_point_id_0 = local_client.retrieve(COLLECTION_NAME, ids=[0], with_vectors=True)[0] + # # assert that we generate different embeddings for doc and query - # region query_points_groups - local_client.query_points_groups( - COLLECTION_NAME, group_by="content", query=sparse_doc_1, using="text" - ) - remote_client.query_points_groups( - COLLECTION_NAME, group_by="content", query=sparse_doc_1, using="text" - ) + nearest_query = models.NearestQuery(nearest=sparse_doc_1) + local_client.query_points(COLLECTION_NAME, nearest_query, using="sparse-text") + remote_client.query_points(COLLECTION_NAME, nearest_query, using="sparse-text") current_query = local_kwargs["query"] assert isinstance(current_query.nearest, models.SparseVector) - doc_point = local_client.retrieve(COLLECTION_NAME, ids=[0], with_vectors=True)[0] - # assert that we generate different embeddings for doc and query - assert not ( - np.allclose(doc_point.vector["text"].values, current_query.nearest.values, atol=1e-3) - ) - prefetch_1 = models.Prefetch( - query=models.NearestQuery(nearest=sparse_doc_2), using="text", limit=3 + recommend_query = models.RecommendQuery( + recommend=models.RecommendInput( + positive=[sparse_doc_1], + negative=[sparse_doc_1], + ) ) - prefetch_2 = models.Prefetch( - query=models.NearestQuery(nearest=sparse_doc_3), using="text", limit=3 - ) - - local_client.query_points_groups( - COLLECTION_NAME, - group_by="content", - query=sparse_doc_1, - prefetch=[prefetch_1, prefetch_2], - using="text", - ) - remote_client.query_points_groups( - COLLECTION_NAME, - group_by="content", - query=sparse_doc_1, - prefetch=[prefetch_1, prefetch_2], - using="text", - ) - current_query = local_kwargs["query"] - current_prefetch = local_kwargs["prefetch"] - assert isinstance(current_query.nearest, models.SparseVector) - assert isinstance(current_prefetch[0].query.nearest, models.SparseVector) - assert isinstance(current_prefetch[1].query.nearest, models.SparseVector) - # endregion - - # region non-prefetch queries - local_client.query_points(COLLECTION_NAME, sparse_doc_1, using="text") - remote_client.query_points(COLLECTION_NAME, sparse_doc_1, using="text") - current_query = local_kwargs["query"] - assert isinstance(current_query.nearest, models.SparseVector) - doc_point = local_client.retrieve(COLLECTION_NAME, ids=[0], with_vectors=True)[0] - # assert that we generate different embeddings for doc and query - assert not ( - np.allclose(doc_point.vector["text"].values, current_query.nearest.values, atol=1e-3) - ) - - nearest_query = models.NearestQuery(nearest=sparse_doc_1) - local_client.query_points(COLLECTION_NAME, nearest_query, using="text") - remote_client.query_points(COLLECTION_NAME, nearest_query, using="text") - current_query = local_kwargs["query"] - assert isinstance(current_query.nearest, models.SparseVector) - - recommend_query = models.RecommendQuery( - recommend=models.RecommendInput( - positive=[sparse_doc_1], - negative=[sparse_doc_1], - ) - ) - local_client.query_points(COLLECTION_NAME, recommend_query, using="text") - remote_client.query_points(COLLECTION_NAME, recommend_query, using="text") - current_query = local_kwargs["query"] - assert all( - isinstance(vector, models.SparseVector) for vector in current_query.recommend.positive - ) - assert all( - isinstance(vector, models.SparseVector) for vector in current_query.recommend.negative + local_client.query_points(COLLECTION_NAME, recommend_query, using="sparse-text") + 
remote_client.query_points(COLLECTION_NAME, recommend_query, using="sparse-text") + current_query = local_kwargs["query"] + assert all( + isinstance(vector, models.SparseVector) for vector in current_query.recommend.positive + ) + assert all( + isinstance(vector, models.SparseVector) for vector in current_query.recommend.negative ) discover_query = models.DiscoverQuery( @@ -557,8 +453,8 @@ def test_query_points_and_query_points_groups(prefer_grpc): ), ) ) - local_client.query_points(COLLECTION_NAME, discover_query, using="text") - remote_client.query_points(COLLECTION_NAME, discover_query, using="text") + local_client.query_points(COLLECTION_NAME, discover_query, using="sparse-text") + remote_client.query_points(COLLECTION_NAME, discover_query, using="sparse-text") current_query = local_kwargs["query"] assert isinstance(current_query.discover.target, models.SparseVector) context_pair = current_query.discover.context @@ -576,8 +472,8 @@ def test_query_points_and_query_points_groups(prefer_grpc): ], ) ) - local_client.query_points(COLLECTION_NAME, discover_query_list, using="text") - remote_client.query_points(COLLECTION_NAME, discover_query_list, using="text") + local_client.query_points(COLLECTION_NAME, discover_query_list, using="sparse-text") + remote_client.query_points(COLLECTION_NAME, discover_query_list, using="sparse-text") current_query = local_kwargs["query"] assert isinstance(current_query.discover.target, models.SparseVector) context_pairs = current_query.discover.context @@ -590,8 +486,8 @@ def test_query_points_and_query_points_groups(prefer_grpc): negative=sparse_doc_2, ) ) - local_client.query_points(COLLECTION_NAME, context_query, using="text") - remote_client.query_points(COLLECTION_NAME, context_query, using="text") + local_client.query_points(COLLECTION_NAME, context_query, using="sparse-text") + remote_client.query_points(COLLECTION_NAME, context_query, using="sparse-text") current_query = local_kwargs["query"] context = current_query.context assert isinstance(context.positive, models.SparseVector) @@ -609,8 +505,8 @@ def test_query_points_and_query_points_groups(prefer_grpc): ), ] ) - local_client.query_points(COLLECTION_NAME, context_query_list, using="text") - remote_client.query_points(COLLECTION_NAME, context_query_list, using="text") + local_client.query_points(COLLECTION_NAME, context_query_list, using="sparse-text") + remote_client.query_points(COLLECTION_NAME, context_query_list, using="sparse-text") current_query = local_kwargs["query"] contexts = current_query.context assert all(isinstance(context.positive, models.SparseVector) for context in contexts) @@ -625,23 +521,23 @@ def test_query_points_and_query_points_groups(prefer_grpc): prefetch=models.Prefetch( query=nearest_query, prefetch=[ - models.Prefetch(query=discover_query_list, limit=5, using="text"), - models.Prefetch(query=nearest_query, using="text", limit=5), + models.Prefetch(query=discover_query_list, limit=5, using="sparse-text"), + models.Prefetch(query=nearest_query, using="sparse-text", limit=5), ], - using="text", + using="sparse-text", limit=4, ), - using="text", + using="sparse-text", limit=3, ), - using="text", + using="sparse-text", limit=2, ) local_client.query_points( - COLLECTION_NAME, query=nearest_query, prefetch=prefetch, limit=1, using="text" + COLLECTION_NAME, query=nearest_query, prefetch=prefetch, limit=1, using="sparse-text" ) remote_client.query_points( - COLLECTION_NAME, query=nearest_query, prefetch=prefetch, limit=1, using="text" + COLLECTION_NAME, query=nearest_query, 
prefetch=prefetch, limit=1, using="sparse-text" ) current_query = local_kwargs["query"] current_prefetch = local_kwargs["prefetch"] @@ -666,6 +562,204 @@ def test_query_points_and_query_points_groups(prefer_grpc): remote_client.delete_collection(COLLECTION_NAME) +@pytest.mark.parametrize("prefer_grpc", [True, False]) +def test_query_points_is_query(prefer_grpc): + # dense_model_name = "jinaai/jina-embeddings-v3" + # dense_dim = 1024 + + local_client = QdrantClient(":memory:") + if not local_client._FASTEMBED_INSTALLED: + pytest.skip("FastEmbed is not installed, skipping") + remote_client = QdrantClient(prefer_grpc=prefer_grpc) + local_kwargs = {} + local_client._client.query_points = arg_interceptor( + local_client._client.query_points, local_kwargs + ) + # dense_doc_1 = models.Document(text="hello world", model=dense_model_name) # todo: uncomment once this model is supported + sparse_doc_1 = models.Document(text="hello world", model=SPARSE_MODEL_NAME) + colbert_doc_1 = models.Document(text="hello world", model=COLBERT_MODEL_NAME) + + vectors_config = { + # "dense-text": models.VectorParams(size=dense_dim, distance=models.Distance.COSINE), + "colbert-text": models.VectorParams( + size=COLBERT_DIM, + distance=models.Distance.COSINE, + multivector_config=models.MultiVectorConfig( + comparator=models.MultiVectorComparator.MAX_SIM + ), + ), + } + sparse_vectors_config = { + "sparse-text": models.SparseVectorParams(modifier=models.Modifier.IDF) + } + + local_client.create_collection( + COLLECTION_NAME, + vectors_config=vectors_config, + sparse_vectors_config=sparse_vectors_config, + ) + if remote_client.collection_exists(COLLECTION_NAME): + remote_client.delete_collection(COLLECTION_NAME) + remote_client.create_collection( + COLLECTION_NAME, + vectors_config=vectors_config, + sparse_vectors_config=sparse_vectors_config, + ) + + points = [ + models.PointStruct( + id=0, vector={"colbert-text": colbert_doc_1, "sparse-text": sparse_doc_1} + ) + ] + local_client.upsert(COLLECTION_NAME, points) + remote_client.upsert(COLLECTION_NAME, points) + + retrieved_point = local_client.retrieve(COLLECTION_NAME, ids=[0], with_vectors=True)[0] + + # local_client.query_points(COLLECTION_NAME, dense_doc_1, using="dense-text") + # remote_client.query_points(COLLECTION_NAME, dense_doc_1, using="dense-text") + # + # assert isinstance(local_kwargs["query"].nearest, list) + # assert not np.allclose(retrieved_point.vector["dense-text"], local_kwargs["query"].nearest, atol=1e-3) + + local_client.query_points(COLLECTION_NAME, sparse_doc_1, using="sparse-text") + remote_client.query_points(COLLECTION_NAME, sparse_doc_1, using="sparse-text") + + assert isinstance(local_kwargs["query"].nearest, models.SparseVector) + assert not np.allclose( + retrieved_point.vector["sparse-text"].values, + local_kwargs["query"].nearest.values, + atol=1e-3, + ) + + local_client.query_points(COLLECTION_NAME, colbert_doc_1, using="colbert-text") + remote_client.query_points(COLLECTION_NAME, colbert_doc_1, using="colbert-text") + + assert isinstance(local_kwargs["query"].nearest, list) + # colbert has a min number of 32 tokens for query + assert len(retrieved_point.vector["colbert-text"]) != len(local_kwargs["query"].nearest) + + local_client.delete_collection(COLLECTION_NAME) + remote_client.delete_collection(COLLECTION_NAME) + + +@pytest.mark.parametrize("prefer_grpc", [True, False]) +def test_query_points_groups(prefer_grpc): + local_client = QdrantClient(":memory:") + if not local_client._FASTEMBED_INSTALLED: + pytest.skip("FastEmbed is not 
installed, skipping") + remote_client = QdrantClient(prefer_grpc=prefer_grpc) + local_kwargs = {} + local_client._client.query_points_groups = arg_interceptor( + local_client._client.query_points_groups, local_kwargs + ) + sparse_doc_1 = models.Document(text="hello world", model=SPARSE_MODEL_NAME) + sparse_doc_2 = models.Document(text="bye world", model=SPARSE_MODEL_NAME) + sparse_doc_3 = models.Document(text="goodbye world", model=SPARSE_MODEL_NAME) + sparse_doc_4 = models.Document(text="good afternoon world", model=SPARSE_MODEL_NAME) + sparse_doc_5 = models.Document(text="good morning world", model=SPARSE_MODEL_NAME) + points = [ + models.PointStruct(id=i, vector={"sparse-text": doc}, payload={"content": doc.text}) + for i, doc in enumerate( + [sparse_doc_1, sparse_doc_2, sparse_doc_3, sparse_doc_4, sparse_doc_5], + ) + ] + + populate_sparse_collection(local_client, points, vector_name="sparse-text") + populate_sparse_collection(remote_client, points, vector_name="sparse-text") + # region query_points_groups + local_client.query_points_groups( + COLLECTION_NAME, group_by="content", query=sparse_doc_1, using="sparse-text" + ) + remote_client.query_points_groups( + COLLECTION_NAME, group_by="content", query=sparse_doc_1, using="sparse-text" + ) + current_query = local_kwargs["query"] + assert isinstance(current_query.nearest, models.SparseVector) + retrieved_point_id_0 = local_client.retrieve(COLLECTION_NAME, ids=[0], with_vectors=True)[0] + # assert that we generate different embeddings for doc and query + # we are using sparse_doc_1 as a query + assert not ( + np.allclose( + retrieved_point_id_0.vector["sparse-text"].values, + current_query.nearest.values, + atol=1e-3, + ) + ) + + prefetch_1 = models.Prefetch( + query=models.NearestQuery(nearest=sparse_doc_2), using="sparse-text", limit=3 + ) + prefetch_2 = models.Prefetch( + query=models.NearestQuery(nearest=sparse_doc_3), using="sparse-text", limit=3 + ) + + local_client.query_points_groups( + COLLECTION_NAME, + group_by="content", + query=sparse_doc_1, + prefetch=[prefetch_1, prefetch_2], + using="sparse-text", + ) + remote_client.query_points_groups( + COLLECTION_NAME, + group_by="content", + query=sparse_doc_1, + prefetch=[prefetch_1, prefetch_2], + using="sparse-text", + ) + current_query = local_kwargs["query"] + current_prefetch = local_kwargs["prefetch"] + assert isinstance(current_query.nearest, models.SparseVector) + assert isinstance(current_prefetch[0].query.nearest, models.SparseVector) + assert isinstance(current_prefetch[1].query.nearest, models.SparseVector) + assert not ( + np.allclose( + retrieved_point_id_0.vector["sparse-text"].values, + current_query.nearest.values, + atol=1e-3, + ) + ) + retrieved_point_id_1 = local_client.retrieve(COLLECTION_NAME, ids=[1], with_vectors=True)[0] + assert not ( + np.allclose( + retrieved_point_id_1.vector["sparse-text"].values, + current_prefetch[0].query.nearest.values, + atol=1e-3, + ) + ) + + assert isinstance(prefetch_1.query.nearest, models.Document) + local_kwargs.clear() + local_client.query_points_groups( + COLLECTION_NAME, + group_by="content", + query=sparse_doc_1, + prefetch=prefetch_1, + using="sparse-text", + ) + remote_client.query_points_groups( + COLLECTION_NAME, + group_by="content", + query=sparse_doc_1, + prefetch=prefetch_1, + using="sparse-text", + ) + current_prefetch = local_kwargs["prefetch"] + assert isinstance(current_prefetch.query.nearest, models.SparseVector) + assert not ( + np.allclose( + retrieved_point_id_1.vector["sparse-text"].values, + 
current_prefetch.query.nearest.values, + atol=1e-3, + ) + ) + # endregion + + local_client.delete_collection(COLLECTION_NAME) + remote_client.delete_collection(COLLECTION_NAME) + + @pytest.mark.parametrize("prefer_grpc", [True, False]) def test_query_batch_points(prefer_grpc): local_client = QdrantClient(":memory:") @@ -677,42 +771,281 @@ def test_query_batch_points(prefer_grpc): local_client._client.query_batch_points, local_kwargs ) - dense_doc_1 = models.Document(text="hello world", model=DENSE_MODEL_NAME) - dense_doc_2 = models.Document(text="bye world", model=DENSE_MODEL_NAME) - dense_doc_3 = models.Document(text="goodbye world", model=DENSE_MODEL_NAME) - dense_doc_4 = models.Document(text="good afternoon world", model=DENSE_MODEL_NAME) - dense_doc_5 = models.Document(text="good morning world", model=DENSE_MODEL_NAME) + sparse_doc_1 = models.Document(text="hello world", model=SPARSE_MODEL_NAME) + sparse_doc_2 = models.Document(text="bye world", model=SPARSE_MODEL_NAME) + sparse_doc_3 = models.Document(text="goodbye world", model=SPARSE_MODEL_NAME) + sparse_doc_4 = models.Document(text="good afternoon world", model=SPARSE_MODEL_NAME) + sparse_doc_5 = models.Document(text="good morning world", model=SPARSE_MODEL_NAME) points = [ - models.PointStruct(id=i, vector=dense_doc) + models.PointStruct(id=i, vector={"sparse-text": dense_doc}) for i, dense_doc in enumerate( - [dense_doc_1, dense_doc_2, dense_doc_3, dense_doc_4, dense_doc_5] + [sparse_doc_1, sparse_doc_2, sparse_doc_3, sparse_doc_4, sparse_doc_5] ) ] - populate_dense_collection(local_client, points) - populate_dense_collection(remote_client, points) + populate_sparse_collection(local_client, points, vector_name="sparse-text") + populate_sparse_collection(remote_client, points, vector_name="sparse-text") - prefetch_1 = models.Prefetch(query=models.NearestQuery(nearest=dense_doc_2), limit=3) - prefetch_2 = models.Prefetch(query=models.NearestQuery(nearest=dense_doc_3), limit=3) + prefetch_1 = models.Prefetch( + query=models.NearestQuery(nearest=sparse_doc_2), limit=3, using="sparse-text" + ) + prefetch_2 = models.Prefetch( + query=models.NearestQuery(nearest=sparse_doc_3), limit=3, using="sparse-text" + ) query_requests = [ - models.QueryRequest(query=models.NearestQuery(nearest=dense_doc_1)), + models.QueryRequest(query=models.NearestQuery(nearest=sparse_doc_1), using="sparse-text"), models.QueryRequest( - query=models.NearestQuery(nearest=dense_doc_2), prefetch=[prefetch_1, prefetch_2] + query=models.NearestQuery(nearest=sparse_doc_2), + prefetch=[prefetch_1, prefetch_2], + using="sparse-text", ), ] - local_client.query_batch_points(COLLECTION_NAME, query_requests) - remote_client.query_batch_points(COLLECTION_NAME, query_requests) - current_requests = local_kwargs["requests"] - assert all([isinstance(request.query.nearest, list) for request in current_requests]) - assert all( - [isinstance(prefetch.query.nearest, list) for prefetch in current_requests[1].prefetch] + local_client.query_batch_points(COLLECTION_NAME, query_requests) + remote_client.query_batch_points(COLLECTION_NAME, query_requests) + current_requests = local_kwargs["requests"] + assert all( + [isinstance(request.query.nearest, models.SparseVector) for request in current_requests] + ) + assert all( + [ + isinstance(prefetch.query.nearest, models.SparseVector) + for prefetch in current_requests[1].prefetch + ] + ) + + retrieved_point = local_client.retrieve(COLLECTION_NAME, ids=[0], with_vectors=True)[0] + assert not np.allclose( + 
retrieved_point.vector["sparse-text"].values, + current_requests[0].query.nearest.values, + atol=1e-3, + ) + local_client.delete_collection(COLLECTION_NAME) + remote_client.delete_collection(COLLECTION_NAME) + + +@pytest.mark.parametrize("prefer_grpc", [True, False]) +def test_batch_update_points(prefer_grpc): + local_client = QdrantClient(":memory:") + if not local_client._FASTEMBED_INSTALLED: + pytest.skip("FastEmbed is not installed, skipping") + remote_client = QdrantClient(prefer_grpc=prefer_grpc) + local_kwargs = {} + local_client._client.batch_update_points = arg_interceptor( + local_client._client.batch_update_points, local_kwargs + ) + + dense_doc_1 = models.Document(text="hello world", model=DENSE_MODEL_NAME) + dense_doc_2 = models.Document(text="bye world", model=DENSE_MODEL_NAME) + + # region unnamed + points = [ + models.PointStruct(id=1, vector=dense_doc_1), + models.PointStruct(id=2, vector=dense_doc_2), + ] + + populate_dense_collection(local_client, points) + populate_dense_collection(remote_client, points) + + batch = models.Batch(ids=[2, 3], vectors=[dense_doc_1, dense_doc_2]) + upsert_operation = models.UpsertOperation(upsert=models.PointsBatch(batch=batch)) + local_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) + remote_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) + current_operation = local_kwargs["update_operations"][0] + current_batch = current_operation.upsert.batch + assert all([isinstance(vector, list) for vector in current_batch.vectors]) + + compare_collections( + local_client, + remote_client, + num_vectors=10, + collection_name=COLLECTION_NAME, + ) + + new_points = [ + models.PointStruct(id=3, vector=dense_doc_1), + models.PointStruct(id=4, vector=dense_doc_2), + ] + upsert_operation = models.UpsertOperation(upsert=models.PointsList(points=new_points)) + local_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) + remote_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) + current_operation = local_kwargs["update_operations"][0] + current_batch = current_operation.upsert.points + assert all([isinstance(vector.vector, list) for vector in current_batch]) + + update_vectors_operation = models.UpdateVectorsOperation( + update_vectors=models.UpdateVectors(points=[models.PointVectors(id=1, vector=dense_doc_2)]) + ) + upsert_operation = models.UpsertOperation( + upsert=models.PointsList(points=[models.PointStruct(id=5, vector=dense_doc_2)]) + ) + local_client.batch_update_points(COLLECTION_NAME, [update_vectors_operation, upsert_operation]) + remote_client.batch_update_points( + COLLECTION_NAME, [update_vectors_operation, upsert_operation] + ) + current_update_operation = local_kwargs["update_operations"][0] + current_upsert_operation = local_kwargs["update_operations"][1] + + assert all( + [ + isinstance(vector.vector, list) + for vector in current_update_operation.update_vectors.points + ] + ) + assert all( + [isinstance(vector.vector, list) for vector in current_upsert_operation.upsert.points] + ) + + local_client.delete_collection(COLLECTION_NAME) + remote_client.delete_collection(COLLECTION_NAME) + # endregion + + # region named + points = [ + models.PointStruct(id=1, vector={"text": dense_doc_1}), + models.PointStruct(id=2, vector={"text": dense_doc_2}), + ] + + populate_dense_collection(local_client, points, vector_name="text") + populate_dense_collection(remote_client, points, vector_name="text") + + batch = models.Batch(ids=[2, 3], vectors={"text": [dense_doc_1, dense_doc_2]}) + upsert_operation = 
models.UpsertOperation(upsert=models.PointsBatch(batch=batch)) + local_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) + remote_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) + current_operation = local_kwargs["update_operations"][0] + current_batch = current_operation.upsert.batch + assert all([isinstance(vector, list) for vector in current_batch.vectors.values()]) + + compare_collections( + local_client, + remote_client, + num_vectors=10, + collection_name=COLLECTION_NAME, + ) + + new_points = [ + models.PointStruct(id=3, vector={"text": dense_doc_1}), + models.PointStruct(id=4, vector={"text": dense_doc_2}), + ] + upsert_operation = models.UpsertOperation(upsert=models.PointsList(points=new_points)) + local_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) + remote_client.batch_update_points(COLLECTION_NAME, [upsert_operation]) + current_operation = local_kwargs["update_operations"][0] + current_batch = current_operation.upsert.points + assert all([isinstance(vector.vector["text"], list) for vector in current_batch]) + + update_vectors_operation = models.UpdateVectorsOperation( + update_vectors=models.UpdateVectors( + points=[models.PointVectors(id=1, vector={"text": dense_doc_2})] + ) + ) + upsert_operation = models.UpsertOperation( + upsert=models.PointsList(points=[models.PointStruct(id=5, vector={"text": dense_doc_2})]) + ) + local_client.batch_update_points(COLLECTION_NAME, [update_vectors_operation, upsert_operation]) + remote_client.batch_update_points( + COLLECTION_NAME, [update_vectors_operation, upsert_operation] + ) + current_update_operation = local_kwargs["update_operations"][0] + current_upsert_operation = local_kwargs["update_operations"][1] + + assert all( + [ + isinstance(vector.vector["text"], list) + for vector in current_update_operation.update_vectors.points + ] + ) + assert all( + [ + isinstance(vector.vector["text"], list) + for vector in current_upsert_operation.upsert.points + ] + ) + + local_client.delete_collection(COLLECTION_NAME) + remote_client.delete_collection(COLLECTION_NAME) + # endregion + + +@pytest.mark.parametrize("prefer_grpc", [True, False]) +def test_update_vectors(prefer_grpc): + local_client = QdrantClient(":memory:") + if not local_client._FASTEMBED_INSTALLED: + pytest.skip("FastEmbed is not installed, skipping") + remote_client = QdrantClient(prefer_grpc=prefer_grpc) + local_kwargs = {} + local_client._client.update_vectors = arg_interceptor( + local_client._client.update_vectors, local_kwargs + ) + + dense_doc_1 = models.Document( + text="hello world", + model=DENSE_MODEL_NAME, + ) + dense_doc_2 = models.Document(text="bye world", model=DENSE_MODEL_NAME) + dense_doc_3 = models.Document(text="goodbye world", model=DENSE_MODEL_NAME) + # region unnamed + points = [ + models.PointStruct(id=1, vector=dense_doc_1), + models.PointStruct(id=2, vector=dense_doc_2), + ] + + populate_dense_collection(local_client, points) + populate_dense_collection(remote_client, points) + + point_vectors = [ + models.PointVectors(id=1, vector=dense_doc_2), + models.PointVectors(id=2, vector=dense_doc_3), + ] + + local_client.update_vectors(COLLECTION_NAME, point_vectors) + remote_client.update_vectors(COLLECTION_NAME, point_vectors) + current_vectors = local_kwargs["points"] + assert all([isinstance(vector.vector, list) for vector in current_vectors]) + + compare_collections( + local_client, + remote_client, + num_vectors=10, + collection_name=COLLECTION_NAME, + ) + + local_client.delete_collection(COLLECTION_NAME) + 
remote_client.delete_collection(COLLECTION_NAME) + # endregion + + # region named + points = [ + models.PointStruct(id=1, vector={"text": dense_doc_1}), + models.PointStruct(id=2, vector={"text": dense_doc_2}), + ] + + populate_dense_collection(local_client, points, vector_name="text") + populate_dense_collection(remote_client, points, vector_name="text") + + point_vectors = [ + models.PointVectors(id=1, vector={"text": dense_doc_2}), + models.PointVectors(id=2, vector={"text": dense_doc_3}), + ] + + local_client.update_vectors(COLLECTION_NAME, point_vectors) + remote_client.update_vectors(COLLECTION_NAME, point_vectors) + current_vectors = local_kwargs["points"] + assert all([isinstance(vector.vector["text"], list) for vector in current_vectors]) + + compare_collections( + local_client, + remote_client, + num_vectors=10, + collection_name=COLLECTION_NAME, ) local_client.delete_collection(COLLECTION_NAME) remote_client.delete_collection(COLLECTION_NAME) + # endregion @pytest.mark.parametrize("prefer_grpc", [True, False]) @@ -778,15 +1111,23 @@ def test_propagate_options(prefer_grpc): local_client.upsert(COLLECTION_NAME, points) remote_client.upsert(COLLECTION_NAME, points) - assert local_client.embedding_models[DENSE_MODEL_NAME].model.lazy_load - assert local_client.sparse_embedding_models[SPARSE_MODEL_NAME].model.lazy_load - assert local_client.late_interaction_embedding_models[COLBERT_MODEL_NAME].model.lazy_load - assert local_client.image_embedding_models[DENSE_IMAGE_MODEL_NAME].model.lazy_load - - local_client.embedding_models.clear() - local_client.sparse_embedding_models.clear() - local_client.late_interaction_embedding_models.clear() - local_client.image_embedding_models.clear() + assert local_client._model_embedder.embedder.embedding_models[DENSE_MODEL_NAME][ + 0 + ].model.model.lazy_load + assert local_client._model_embedder.embedder.sparse_embedding_models[SPARSE_MODEL_NAME][ + 0 + ].model.model.lazy_load + assert local_client._model_embedder.embedder.late_interaction_embedding_models[ + COLBERT_MODEL_NAME + ][0].model.model.lazy_load + assert local_client._model_embedder.embedder.image_embedding_models[DENSE_IMAGE_MODEL_NAME][ + 0 + ].model.model.lazy_load + + local_client._model_embedder.embedder.embedding_models.clear() + local_client._model_embedder.embedder.sparse_embedding_models.clear() + local_client._model_embedder.embedder.late_interaction_embedding_models.clear() + local_client._model_embedder.embedder.image_embedding_models.clear() inference_object_dense_doc_1 = models.InferenceObject( object="hello world", @@ -827,48 +1168,18 @@ def test_propagate_options(prefer_grpc): local_client.upsert(COLLECTION_NAME, points) remote_client.upsert(COLLECTION_NAME, points) - assert local_client.embedding_models[DENSE_MODEL_NAME].model.lazy_load - assert local_client.sparse_embedding_models[SPARSE_MODEL_NAME].model.lazy_load - assert local_client.late_interaction_embedding_models[COLBERT_MODEL_NAME].model.lazy_load - assert local_client.image_embedding_models[DENSE_IMAGE_MODEL_NAME].model.lazy_load - - -@pytest.mark.parametrize("prefer_grpc", [True, False]) -def test_image(prefer_grpc): - local_client = QdrantClient(":memory:") - if not local_client._FASTEMBED_INSTALLED: - pytest.skip("FastEmbed is not installed, skipping") - remote_client = QdrantClient(prefer_grpc=prefer_grpc) - local_kwargs = {} - local_client._client.upsert = arg_interceptor(local_client._client.upsert, local_kwargs) - - dense_image_1 = models.Image(image=TEST_IMAGE_PATH, model=DENSE_IMAGE_MODEL_NAME) - points = 
[ - models.PointStruct(id=i, vector=dense_img) for i, dense_img in enumerate([dense_image_1]) - ] - - for client in local_client, remote_client: - if client.collection_exists(COLLECTION_NAME): - client.delete_collection(COLLECTION_NAME) - vector_params = models.VectorParams(size=DENSE_IMAGE_DIM, distance=models.Distance.COSINE) - client.create_collection(COLLECTION_NAME, vectors_config=vector_params) - client.upsert(COLLECTION_NAME, points) - - vec_points = local_kwargs["points"] - assert all([isinstance(vec_point.vector, list) for vec_point in vec_points]) - assert local_client.scroll(COLLECTION_NAME, limit=1, with_vectors=True)[0] - compare_collections( - local_client, - remote_client, - num_vectors=10, - collection_name=COLLECTION_NAME, - ) - - local_client.query_points(COLLECTION_NAME, dense_image_1) - remote_client.query_points(COLLECTION_NAME, dense_image_1) - - local_client.delete_collection(COLLECTION_NAME) - remote_client.delete_collection(COLLECTION_NAME) + assert local_client._model_embedder.embedder.embedding_models[DENSE_MODEL_NAME][ + 0 + ].model.model.lazy_load + assert local_client._model_embedder.embedder.sparse_embedding_models[SPARSE_MODEL_NAME][ + 0 + ].model.model.lazy_load + assert local_client._model_embedder.embedder.late_interaction_embedding_models[ + COLBERT_MODEL_NAME + ][0].model.model.lazy_load + assert local_client._model_embedder.embedder.image_embedding_models[DENSE_IMAGE_MODEL_NAME][ + 0 + ].model.model.lazy_load @pytest.mark.parametrize("prefer_grpc", [True, False]) @@ -968,3 +1279,377 @@ def test_inference_object(prefer_grpc): local_client.delete_collection(COLLECTION_NAME) remote_client.delete_collection(COLLECTION_NAME) + + +@pytest.mark.parametrize("prefer_grpc", [False, True]) +@pytest.mark.parametrize("parallel", [1, 2]) +def test_upload_mixed_batches_upload_points(prefer_grpc, parallel): + local_client = QdrantClient(":memory:") + if not local_client._FASTEMBED_INSTALLED: + pytest.skip("FastEmbed is not installed, skipping") + remote_client = QdrantClient(prefer_grpc=prefer_grpc) + half_dense_dim = DENSE_DIM // 2 + batch_size = 2 + + ref_vector = [0.0, 0.2] * half_dense_dim + norm_ref_vector = (np.array(ref_vector) / np.linalg.norm(ref_vector)).tolist() + + # region separate plain batches + points = [ + models.PointStruct( + id=1, vector=models.Document(text="hello world", model=DENSE_MODEL_NAME) + ), + models.PointStruct(id=2, vector=models.Document(text="bye world", model=DENSE_MODEL_NAME)), + models.PointStruct(id=3, vector=ref_vector), + models.PointStruct(id=4, vector=[0.1, 0.2] * half_dense_dim), + ] + + vectors_config = models.VectorParams(size=DENSE_DIM, distance=models.Distance.COSINE) + local_client.create_collection(COLLECTION_NAME, vectors_config=vectors_config) + if remote_client.collection_exists(COLLECTION_NAME): + remote_client.delete_collection(COLLECTION_NAME) + remote_client.create_collection(COLLECTION_NAME, vectors_config=vectors_config) + + local_client.upload_points( + COLLECTION_NAME, points, batch_size=batch_size, wait=True, parallel=parallel + ) + remote_client.upload_points( + COLLECTION_NAME, points, batch_size=batch_size, wait=True, parallel=parallel + ) + + assert remote_client.count(COLLECTION_NAME).count == len(points) + assert np.allclose( + remote_client.retrieve(COLLECTION_NAME, ids=[3], with_vectors=True)[0].vector, + norm_ref_vector, + ) + + compare_collections( + local_client, + remote_client, + num_vectors=10, + collection_name=COLLECTION_NAME, + ) + + local_client.delete_collection(COLLECTION_NAME) + 
remote_client.delete_collection(COLLECTION_NAME) + # endregion + + # region mixed plain batches + points = [ + models.PointStruct( + id=1, vector=models.Document(text="hello world", model=DENSE_MODEL_NAME) + ), + models.PointStruct(id=2, vector=ref_vector), + models.PointStruct(id=3, vector=models.Document(text="bye world", model=DENSE_MODEL_NAME)), + models.PointStruct(id=4, vector=[0.1, 0.2] * half_dense_dim), + ] + + local_client.create_collection(COLLECTION_NAME, vectors_config=vectors_config) + remote_client.create_collection(COLLECTION_NAME, vectors_config=vectors_config) + + local_client.upload_points( + COLLECTION_NAME, points, batch_size=batch_size, wait=True, parallel=parallel + ) + remote_client.upload_points( + COLLECTION_NAME, points, batch_size=batch_size, wait=True, parallel=parallel + ) + + assert remote_client.count(COLLECTION_NAME).count == len(points) + assert np.allclose( + remote_client.retrieve(COLLECTION_NAME, ids=[2], with_vectors=True)[0].vector, + norm_ref_vector, + ) + + compare_collections( + local_client, + remote_client, + num_vectors=10, + collection_name=COLLECTION_NAME, + ) + + local_client.delete_collection(COLLECTION_NAME) + remote_client.delete_collection(COLLECTION_NAME) + # endregion + + # region mixed named batches + + vectors_config = { + "text": models.VectorParams(size=DENSE_DIM, distance=models.Distance.COSINE), + "plain": models.VectorParams(size=DENSE_DIM, distance=models.Distance.COSINE), + } + points = [ + models.PointStruct( + id=1, + vector={ + "text": models.Document(text="hello world", model=DENSE_MODEL_NAME), + "plain": [0.1, 0.2] * half_dense_dim, + }, + ), + models.PointStruct( + id=2, + vector={ + "plain": ref_vector, + "text": models.Document(text="bye world", model=DENSE_MODEL_NAME), + }, + ), + models.PointStruct( + id=3, + vector={"plain": [0.3, 0.2] * half_dense_dim}, + ), + models.PointStruct( + id=4, + vector={"text": models.Document(text="bye world", model=DENSE_MODEL_NAME)}, + ), + ] + + local_client.create_collection(COLLECTION_NAME, vectors_config=vectors_config) + remote_client.create_collection(COLLECTION_NAME, vectors_config=vectors_config) + + local_client.upload_points( + COLLECTION_NAME, points, batch_size=batch_size, wait=True, parallel=parallel + ) + remote_client.upload_points( + COLLECTION_NAME, points, batch_size=batch_size, wait=True, parallel=parallel + ) + + assert remote_client.count(COLLECTION_NAME).count == len(points) + assert np.allclose( + remote_client.retrieve(COLLECTION_NAME, ids=[2], with_vectors=True)[0].vector["plain"], + norm_ref_vector, + ) + compare_collections( + local_client, + remote_client, + num_vectors=10, + collection_name=COLLECTION_NAME, + ) + + local_client.delete_collection(COLLECTION_NAME) + remote_client.delete_collection(COLLECTION_NAME) + # endregion + + +@pytest.mark.parametrize("prefer_grpc", [False, True]) +@pytest.mark.parametrize("parallel", [1, 2]) +def test_upload_mixed_batches_upload_collection(prefer_grpc, parallel): + local_client = QdrantClient(":memory:") + if not local_client._FASTEMBED_INSTALLED: + pytest.skip("FastEmbed is not installed, skipping") + remote_client = QdrantClient(prefer_grpc=prefer_grpc) + half_dense_dim = DENSE_DIM // 2 + batch_size = 2 + ref_vector = [0.0, 0.2] * half_dense_dim + norm_ref_vector = (np.array(ref_vector) / np.linalg.norm(ref_vector)).tolist() + + # region separate plain batches + ids = [0, 1, 2, 3] + vectors = [ + models.Document(text="hello world", model=DENSE_MODEL_NAME), + models.Document(text="bye world", model=DENSE_MODEL_NAME), 
+ ref_vector, + [0.1, 0.2] * half_dense_dim, + ] + + vectors_config = models.VectorParams(size=DENSE_DIM, distance=models.Distance.COSINE) + local_client.create_collection(COLLECTION_NAME, vectors_config=vectors_config) + if remote_client.collection_exists(COLLECTION_NAME): + remote_client.delete_collection(COLLECTION_NAME) + remote_client.create_collection(COLLECTION_NAME, vectors_config=vectors_config) + + local_client.upload_collection( + COLLECTION_NAME, + ids=ids, + vectors=vectors, + batch_size=batch_size, + wait=True, + parallel=parallel, + ) + remote_client.upload_collection( + COLLECTION_NAME, + ids=ids, + vectors=vectors, + batch_size=batch_size, + wait=True, + parallel=parallel, + ) + + assert remote_client.count(COLLECTION_NAME).count == len(vectors) + assert np.allclose( + remote_client.retrieve(COLLECTION_NAME, ids=[2], with_vectors=True)[0].vector, + norm_ref_vector, + ) + + compare_collections( + local_client, + remote_client, + num_vectors=10, + collection_name=COLLECTION_NAME, + ) + + local_client.delete_collection(COLLECTION_NAME) + remote_client.delete_collection(COLLECTION_NAME) + # endregion + + # region mixed plain batches + vectors = [ + models.Document(text="hello world", model=DENSE_MODEL_NAME), + ref_vector, + models.Document(text="bye world", model=DENSE_MODEL_NAME), + [0.1, 0.2] * half_dense_dim, + ] + + local_client.create_collection(COLLECTION_NAME, vectors_config=vectors_config) + remote_client.create_collection(COLLECTION_NAME, vectors_config=vectors_config) + + local_client.upload_collection( + COLLECTION_NAME, ids=ids, vectors=vectors, batch_size=batch_size, parallel=parallel + ) + remote_client.upload_collection( + COLLECTION_NAME, + ids=ids, + vectors=vectors, + batch_size=batch_size, + wait=True, + parallel=parallel, + ) + + assert remote_client.count(COLLECTION_NAME).count == len(vectors) + assert np.allclose( + remote_client.retrieve(COLLECTION_NAME, ids=[1], with_vectors=True)[0].vector, + norm_ref_vector, + ) + + compare_collections( + local_client, + remote_client, + num_vectors=10, + collection_name=COLLECTION_NAME, + ) + + local_client.delete_collection(COLLECTION_NAME) + remote_client.delete_collection(COLLECTION_NAME) + # endregion + + # region mixed named batches + + vectors_config = { + "text": models.VectorParams(size=DENSE_DIM, distance=models.Distance.COSINE), + "plain": models.VectorParams(size=DENSE_DIM, distance=models.Distance.COSINE), + } + vectors = [ + { + "text": models.Document(text="hello world", model=DENSE_MODEL_NAME), + "plain": [0.0, 0.2] * half_dense_dim, + }, + { + "plain": ref_vector, + "text": models.Document(text="bye world", model=DENSE_MODEL_NAME), + }, + {"plain": [0.3, 0.2] * half_dense_dim}, + {"text": models.Document(text="bye world", model=DENSE_MODEL_NAME)}, + ] + + local_client.create_collection(COLLECTION_NAME, vectors_config=vectors_config) + remote_client.create_collection(COLLECTION_NAME, vectors_config=vectors_config) + + local_client.upload_collection( + COLLECTION_NAME, + ids=ids, + vectors=vectors, + batch_size=batch_size, + wait=True, + parallel=parallel, + ) + remote_client.upload_collection( + COLLECTION_NAME, + ids=ids, + vectors=vectors, + batch_size=batch_size, + wait=True, + parallel=parallel, + ) + + assert remote_client.count(COLLECTION_NAME).count == len(vectors) + assert np.allclose( + remote_client.retrieve(COLLECTION_NAME, ids=[1], with_vectors=True)[0].vector["plain"], + norm_ref_vector, + ) + + compare_collections( + local_client, + remote_client, + num_vectors=10, + 
collection_name=COLLECTION_NAME, + ) + + local_client.delete_collection(COLLECTION_NAME) + remote_client.delete_collection(COLLECTION_NAME) + # endregion + + +@pytest.mark.parametrize("prefer_grpc", [True, False]) +def test_upsert_batch_with_different_options(prefer_grpc): + bm25_name = "Qdrant/bm25" + local_client = QdrantClient(":memory:") + if not local_client._FASTEMBED_INSTALLED: + pytest.skip("FastEmbed is not installed, skipping") + remote_client = QdrantClient(prefer_grpc=prefer_grpc) + + sparse_doc_1 = models.Document( + text="running run", model=bm25_name, options={"language": "english"} + ) + sparse_doc_2 = models.Document( + text="running run", model=bm25_name, options={"language": "german"} + ) + sparse_doc_3 = models.Document( + text="running run", model=bm25_name, options={"language": "english"} + ) + sparse_doc_4 = models.Document( + text="running run", model=bm25_name, options={"language": "german"} + ) + + sparse_vectors_config = { + "sparse-text-en": models.SparseVectorParams(modifier=models.Modifier.IDF), + "sparse-text-de": models.SparseVectorParams(modifier=models.Modifier.IDF), + } + if remote_client.collection_exists(COLLECTION_NAME): + remote_client.delete_collection(COLLECTION_NAME) + + local_client.create_collection( + COLLECTION_NAME, vectors_config={}, sparse_vectors_config=sparse_vectors_config + ) + remote_client.create_collection( + COLLECTION_NAME, vectors_config={}, sparse_vectors_config=sparse_vectors_config + ) + points = [ + models.PointStruct( + id=0, vector={"sparse-text-en": sparse_doc_1, "sparse-text-de": sparse_doc_2} + ), + models.PointStruct(id=1, vector={"sparse-text-en": sparse_doc_3}), + models.PointStruct(id=2, vector={"sparse-text-de": sparse_doc_4}), + ] + + local_client.upsert(COLLECTION_NAME, points) + remote_client.upsert(COLLECTION_NAME, points) + + read_points, _ = local_client.scroll(COLLECTION_NAME, limit=4, with_vectors=True) + assert len(read_points) == 3 + assert ( + read_points[0].vector["sparse-text-en"].indices + != read_points[0].vector["sparse-text-de"].indices + ) + assert ( + read_points[0].vector["sparse-text-en"].indices + == read_points[1].vector["sparse-text-en"].indices + ) + assert ( + read_points[0].vector["sparse-text-de"].indices + == read_points[2].vector["sparse-text-de"].indices + ) + + compare_collections( + local_client, + remote_client, + num_vectors=10, + collection_name=COLLECTION_NAME, + ) diff --git a/tests/test_fastembed.py b/tests/test_fastembed.py index 9c02e257..d089dd9b 100644 --- a/tests/test_fastembed.py +++ b/tests/test_fastembed.py @@ -201,3 +201,21 @@ def test_idf_models(): # the only sparse model without IDF is SPLADE, however it's too large for tests, so we don't test how non-idf # models work + + +def test_get_embedding_size(): + local_client = QdrantClient(":memory:") + + if not local_client._FASTEMBED_INSTALLED: + pytest.skip("FastEmbed is not installed, skipping test") + + assert local_client.get_embedding_size() == 384 + + assert local_client.get_embedding_size(model_name="BAAI/bge-base-en-v1.5") == 768 + + assert local_client.get_embedding_size(model_name="Qdrant/resnet50-onnx") == 2048 + + assert local_client.get_embedding_size(model_name="colbert-ir/colbertv2.0") == 128 + + with pytest.raises(ValueError, match="Sparse embeddings do not have a fixed embedding size."): + local_client.get_embedding_size(model_name="Qdrant/bm25") diff --git a/tools/async_client_generator/fastembed_generator.py b/tools/async_client_generator/fastembed_generator.py index fe398c3f..cca82837 100644 --- 
a/tools/async_client_generator/fastembed_generator.py +++ b/tools/async_client_generator/fastembed_generator.py @@ -57,6 +57,7 @@ def get_async_methods(class_obj: type) -> list[str]: "set_sparse_model", "get_vector_field_name", "get_sparse_vector_field_name", + "get_embedding_size", "get_fastembed_vector_params", "get_fastembed_sparse_vector_params", "embedding_model_name",
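

Taken together, the tests added in this patch exercise two parts of the local-inference workflow: looking up a model's output dimensionality with the newly exposed get_embedding_size helper, and upserting models.Document objects alongside plain vectors so they are embedded client-side. The snippet below is a minimal illustrative sketch of that flow, not part of the patch itself; it assumes FastEmbed is installed and uses a placeholder collection name "demo".

from qdrant_client import QdrantClient, models

client = QdrantClient(":memory:")
model_name = "BAAI/bge-base-en-v1.5"

# get_embedding_size returns the dense model's output dimensionality (768 for bge-base
# in the test above), which can be fed directly into the collection config.
dim = client.get_embedding_size(model_name=model_name)
client.create_collection(
    "demo",
    vectors_config=models.VectorParams(size=dim, distance=models.Distance.COSINE),
)

# Document objects are embedded locally on upsert; plain vectors can be mixed in,
# mirroring the mixed-batch upload tests above.
client.upsert(
    "demo",
    points=[
        models.PointStruct(
            id=0, vector=models.Document(text="hello world", model=model_name)
        ),
        models.PointStruct(id=1, vector=[0.1, 0.2] * (dim // 2)),
    ],
)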