diff --git a/fastembed/text/clip_embedding.py b/fastembed/text/clip_embedding.py
index b0721bf1..7a66a921 100644
--- a/fastembed/text/clip_embedding.py
+++ b/fastembed/text/clip_embedding.py
@@ -22,6 +22,8 @@ class CLIPOnnxEmbedding(OnnxTextEmbedding):
+    supported_models = supported_clip_models
+
     @classmethod
     def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
         return CLIPEmbeddingWorker
@@ -33,7 +35,11 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
         Returns:
             list[dict[str, Any]]: A list of dictionaries containing the model information.
         """
-        return supported_clip_models
+        return cls.supported_models
+
+    @classmethod
+    def add_custom_model(cls, model_info: dict[str, Any]):
+        cls.supported_models.append(model_info)
 
     def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
         return output.model_output
diff --git a/fastembed/text/multitask_embedding.py b/fastembed/text/multitask_embedding.py
index cc479c21..f481a7a9 100644
--- a/fastembed/text/multitask_embedding.py
+++ b/fastembed/text/multitask_embedding.py
@@ -41,6 +41,7 @@ class Task(int, Enum):
 class JinaEmbeddingV3(PooledNormalizedEmbedding):
     PASSAGE_TASK = Task.RETRIEVAL_PASSAGE
     QUERY_TASK = Task.RETRIEVAL_QUERY
+    supported_models = supported_multitask_models
 
     def __init__(self, *args: Any, **kwargs: Any):
         super().__init__(*args, **kwargs)
@@ -52,7 +53,11 @@ def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:
 
     @classmethod
     def list_supported_models(cls) -> list[dict[str, Any]]:
-        return supported_multitask_models
+        return cls.supported_models
+
+    @classmethod
+    def add_custom_model(cls, model_info: dict[str, Any]):
+        cls.supported_models.append(model_info)
 
     def _preprocess_onnx_input(
         self, onnx_input: dict[str, np.ndarray], **kwargs
diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py
index a93920cf..32b78d28 100644
--- a/fastembed/text/onnx_embedding.py
+++ b/fastembed/text/onnx_embedding.py
@@ -173,6 +173,8 @@ class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[np.ndarray]):
     """Implementation of the Flag Embedding model."""
 
+    supported_models = supported_onnx_models
+
     @classmethod
     def list_supported_models(cls) -> list[dict[str, Any]]:
         """
@@ -181,7 +183,11 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
         Returns:
             list[dict[str, Any]]: A list of dictionaries containing the model information.
         """
-        return supported_onnx_models
+        return cls.supported_models
+
+    @classmethod
+    def add_custom_model(cls, model_info: dict[str, Any]):
+        cls.supported_models.append(model_info)
 
     def __init__(
         self,
diff --git a/fastembed/text/pooled_embedding.py b/fastembed/text/pooled_embedding.py
index 063c47bd..721fc80d 100644
--- a/fastembed/text/pooled_embedding.py
+++ b/fastembed/text/pooled_embedding.py
@@ -79,6 +79,8 @@ class PooledEmbedding(OnnxTextEmbedding):
+    supported_models = supported_pooled_models
+
     @classmethod
     def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
         return PooledEmbeddingWorker
@@ -101,7 +103,11 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
         Returns:
             list[dict[str, Any]]: A list of dictionaries containing the model information.
""" - return supported_pooled_models + return cls.supported_models + + @classmethod + def add_custom_model(cls, model_info: dict[str, Any]): + cls.supported_models.append(model_info) def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]: if output.attention_mask is None: diff --git a/fastembed/text/pooled_normalized_embedding.py b/fastembed/text/pooled_normalized_embedding.py index 5d5cef75..cc48a14e 100644 --- a/fastembed/text/pooled_normalized_embedding.py +++ b/fastembed/text/pooled_normalized_embedding.py @@ -88,6 +88,8 @@ class PooledNormalizedEmbedding(PooledEmbedding): + supported_models = supported_pooled_normalized_models + @classmethod def _get_worker_class(cls) -> Type[TextEmbeddingWorker]: return PooledNormalizedEmbeddingWorker @@ -99,7 +101,11 @@ def list_supported_models(cls) -> list[dict[str, Any]]: Returns: list[dict[str, Any]]: A list of dictionaries containing the model information. """ - return supported_pooled_normalized_models + return cls.supported_models + + @classmethod + def add_custom_model(cls, model_info: dict[str, Any]): + cls.supported_models.append(model_info) def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]: if output.attention_mask is None: diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index 30e4bba5..eac2595c 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -50,6 +50,38 @@ def list_supported_models(cls) -> list[dict[str, Any]]: result.extend(embedding.list_supported_models()) return result + @classmethod + def add_custom_model( + cls, model_info: dict[str, Any], mean_pooling: bool = True, normalization: bool = False + ) -> None: + """ + Register a custom model so that TextEmbedding(...) can find it later. + + Args: + model_info: Dictionary describing the model, e.g.: + { + "model": "alibaba/blablabla", + "dim": 512, + "description": "...", + "license": "apache-2.0", + "size_in_GB": 1.23, + "sources": { ... } # optional + } + mean_pooling: apply mean_pooling or not. + normalization: apply normalization or not. 
+
+        Returns:
+            None
+        """
+        if mean_pooling and not normalization:
+            PooledEmbedding.add_custom_model(model_info)
+        elif mean_pooling and normalization:
+            PooledNormalizedEmbedding.add_custom_model(model_info)
+        elif "clip" in model_info["model"].lower():
+            CLIPOnnxEmbedding.add_custom_model(model_info)
+        else:
+            OnnxTextEmbedding.add_custom_model(model_info)
+
     def __init__(
         self,
         model_name: str = "BAAI/bge-small-en-v1.5",
diff --git a/tests/test_add_custom_model.py b/tests/test_add_custom_model.py
new file mode 100644
index 00000000..e493daf0
--- /dev/null
+++ b/tests/test_add_custom_model.py
@@ -0,0 +1,115 @@
+import os
+
+import numpy as np
+import pytest
+
+from fastembed.text.text_embedding import TextEmbedding
+from tests.utils import delete_model_cache
+
+canonical_vectors = [
+    {
+        "model": "intfloat/multilingual-e5-small",
+        "mean_pooling": True,
+        "normalization": True,
+        "canonical_vector": [3.1317e-02, 3.0939e-02, -3.5117e-02, -6.7274e-02, 8.5084e-02],
+    },
+    {
+        "model": "intfloat/multilingual-e5-small",
+        "mean_pooling": True,
+        "normalization": False,
+        "canonical_vector": [1.4604e-01, 1.4428e-01, -1.6376e-01, -3.1372e-01, 3.9677e-01],
+    },
+    {
+        "model": "mixedbread-ai/mxbai-embed-xsmall-v1",
+        "mean_pooling": False,
+        "normalization": False,
+        "canonical_vector": [
+            2.49407589e-02,
+            1.00189969e-02,
+            1.07807154e-02,
+            3.63860987e-02,
+            -2.27128249e-02,
+        ],
+    },
+]
+
+DIMENSIONS = {
+    "intfloat/multilingual-e5-small": 384,
+    "mixedbread-ai/mxbai-embed-xsmall-v1": 384,
+}
+
+SOURCES = {
+    "intfloat/multilingual-e5-small": "intfloat/multilingual-e5-small",
+    "mixedbread-ai/mxbai-embed-xsmall-v1": "mixedbread-ai/mxbai-embed-xsmall-v1",
+}
+
+
+@pytest.mark.parametrize("scenario", canonical_vectors)
+def test_add_custom_model_variations(scenario):
+    is_ci = bool(os.getenv("CI", False))
+
+    base_model_name = scenario["model"]
+    mean_pooling = scenario["mean_pooling"]
+    normalization = scenario["normalization"]
+    cv = np.array(scenario["canonical_vector"], dtype=np.float32)
+
+    backup_supported_models = {}
+    for embedding_cls in TextEmbedding.EMBEDDINGS_REGISTRY:
+        backup_supported_models[embedding_cls] = embedding_cls.list_supported_models().copy()
+
+    suffixes = []
+    suffixes.append("mean" if mean_pooling else "no-mean")
+    suffixes.append("norm" if normalization else "no-norm")
+    suffix_str = "-".join(suffixes)
+
+    custom_model_name = f"{base_model_name}-{suffix_str}"
+
+    dim = DIMENSIONS[base_model_name]
+    hf_source = SOURCES[base_model_name]
+
+    model_info = {
+        "model": custom_model_name,
+        "dim": dim,
+        "description": f"{base_model_name} with {suffix_str}",
+        "license": "mit",
+        "size_in_GB": 0.13,
+        "sources": {
+            "hf": hf_source,
+        },
+        "model_file": "onnx/model.onnx",
+        "additional_files": [],
+    }
+
+    if is_ci and model_info["size_in_GB"] > 1.0:
+        pytest.skip(
+            f"Skipping {custom_model_name} on CI due to size_in_GB={model_info['size_in_GB']}"
+        )
+
+    try:
+        TextEmbedding.add_custom_model(
+            model_info=model_info, mean_pooling=mean_pooling, normalization=normalization
+        )
+
+        model = TextEmbedding(model_name=custom_model_name)
+
+        docs = ["hello world", "flag embedding"]
+        embeddings = list(model.embed(docs))
+        embeddings = np.stack(embeddings, axis=0)
+
+        assert embeddings.shape == (
+            2,
+            dim,
+        ), f"Expected shape (2, {dim}) for {custom_model_name}, but got {embeddings.shape}"
+
+        num_compare_dims = cv.shape[0]
+        assert np.allclose(
+            embeddings[0, :num_compare_dims], cv, atol=1e-3
+        ), f"Embedding mismatch for {custom_model_name} (first {num_compare_dims} dims)."
+
+        assert not np.allclose(embeddings[0, :], 0.0), "Embedding should not be all zeros."
+
+        if is_ci:
+            delete_model_cache(model.model._model_dir)
+
+    finally:
+        for embedding_cls, old_list in backup_supported_models.items():
+            embedding_cls.supported_models = old_list
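
For reference, a minimal usage sketch of the new registration flow, mirroring the test above. This is not part of the patch: the custom model name and metadata below are illustrative placeholders, and `add_custom_model` only registers metadata, so the `hf` source must point at a repository that actually ships the referenced ONNX export.

from fastembed.text.text_embedding import TextEmbedding

# Illustrative metadata: a hypothetical alias for an existing ONNX export.
# The keys mirror the entries used in tests/test_add_custom_model.py.
model_info = {
    "model": "intfloat/multilingual-e5-small-custom",  # hypothetical name
    "dim": 384,
    "description": "e5-small registered as a custom model",
    "license": "mit",
    "size_in_GB": 0.13,
    "sources": {"hf": "intfloat/multilingual-e5-small"},
    "model_file": "onnx/model.onnx",
    "additional_files": [],
}

# mean_pooling=True + normalization=True routes the registration to
# PooledNormalizedEmbedding; the other flag combinations dispatch to
# PooledEmbedding, CLIPOnnxEmbedding, or OnnxTextEmbedding as in the patch.
TextEmbedding.add_custom_model(model_info, mean_pooling=True, normalization=True)

# The custom name is now resolvable like any built-in model.
model = TextEmbedding(model_name="intfloat/multilingual-e5-small-custom")
vectors = list(model.embed(["hello world"]))  # list of 384-dim numpy arrays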