Add custom model #463

Open · wants to merge 4 commits into base: main
8 changes: 7 additions & 1 deletion fastembed/text/clip_embedding.py
@@ -22,6 +22,8 @@


class CLIPOnnxEmbedding(OnnxTextEmbedding):
supported_models = supported_clip_models

@classmethod
def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
return CLIPEmbeddingWorker
@@ -33,7 +35,11 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
Returns:
list[dict[str, Any]]: A list of dictionaries containing the model information.
"""
-return supported_clip_models
+return cls.supported_models

@classmethod
def add_custom_model(cls, model_info: dict[str, Any]):
cls.supported_models.append(model_info)

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
return output.model_output
7 changes: 6 additions & 1 deletion fastembed/text/multitask_embedding.py
@@ -41,6 +41,7 @@ class Task(int, Enum):
class JinaEmbeddingV3(PooledNormalizedEmbedding):
PASSAGE_TASK = Task.RETRIEVAL_PASSAGE
QUERY_TASK = Task.RETRIEVAL_QUERY
supported_models = supported_multitask_models

def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
@@ -52,7 +53,11 @@ def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:

@classmethod
def list_supported_models(cls) -> list[dict[str, Any]]:
-return supported_multitask_models
+return cls.supported_models

@classmethod
def add_custom_model(cls, model_info: dict[str, Any]):
cls.supported_models.append(model_info)

def _preprocess_onnx_input(
self, onnx_input: dict[str, np.ndarray], **kwargs
8 changes: 7 additions & 1 deletion fastembed/text/onnx_embedding.py
@@ -173,6 +173,8 @@
class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[np.ndarray]):
"""Implementation of the Flag Embedding model."""

supported_models = supported_onnx_models

@classmethod
def list_supported_models(cls) -> list[dict[str, Any]]:
"""
@@ -181,7 +183,11 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
Returns:
list[dict[str, Any]]: A list of dictionaries containing the model information.
"""
-return supported_onnx_models
+return cls.supported_models

@classmethod
def add_custom_model(cls, model_info: dict[str, Any]):
cls.supported_models.append(model_info)

def __init__(
self,
8 changes: 7 additions & 1 deletion fastembed/text/pooled_embedding.py
@@ -79,6 +79,8 @@


class PooledEmbedding(OnnxTextEmbedding):
supported_models = supported_pooled_models

@classmethod
def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
return PooledEmbeddingWorker
@@ -101,7 +103,11 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
Returns:
list[dict[str, Any]]: A list of dictionaries containing the model information.
"""
-return supported_pooled_models
+return cls.supported_models

@classmethod
def add_custom_model(cls, model_info: dict[str, Any]):
cls.supported_models.append(model_info)
Comment on lines +108 to +110

@joein (Member) commented on Feb 10, 2025:

There is no need to redefine this method in the subclasses if they fetch the right `cls` variable.


def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
if output.attention_mask is None:
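To illustrate the review comment above, here is a minimal sketch (illustrative only, not code from this PR) of how a single base-class add_custom_model can serve every subclass once each subclass merely overrides the supported_models class attribute:

from typing import Any

# Stand-ins for the real module-level registries used in fastembed.
supported_onnx_models: list[dict[str, Any]] = []
supported_pooled_models: list[dict[str, Any]] = []


class OnnxTextEmbedding:
    supported_models = supported_onnx_models

    @classmethod
    def add_custom_model(cls, model_info: dict[str, Any]) -> None:
        # `cls` resolves to whichever subclass the method is called on,
        # so this appends to that subclass's own registry without any
        # per-subclass redefinition of the method.
        cls.supported_models.append(model_info)


class PooledEmbedding(OnnxTextEmbedding):
    supported_models = supported_pooled_models  # only the attribute differs


PooledEmbedding.add_custom_model({"model": "example/model", "dim": 384})
assert supported_pooled_models == [{"model": "example/model", "dim": 384}]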
8 changes: 7 additions & 1 deletion fastembed/text/pooled_normalized_embedding.py
@@ -88,6 +88,8 @@


class PooledNormalizedEmbedding(PooledEmbedding):
supported_models = supported_pooled_normalized_models

@classmethod
def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
return PooledNormalizedEmbeddingWorker
@@ -99,7 +101,11 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
Returns:
list[dict[str, Any]]: A list of dictionaries containing the model information.
"""
-return supported_pooled_normalized_models
+return cls.supported_models

@classmethod
def add_custom_model(cls, model_info: dict[str, Any]):
cls.supported_models.append(model_info)

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
if output.attention_mask is None:
32 changes: 32 additions & 0 deletions fastembed/text/text_embedding.py
@@ -50,6 +50,38 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
result.extend(embedding.list_supported_models())
return result

@classmethod
def add_custom_model(
cls, model_info: dict[str, Any], mean_pooling: bool = True, normalization: bool = False
) -> None:
"""
Register a custom model so that TextEmbedding(...) can find it later.

Args:
model_info: Dictionary describing the model, e.g.:
{
"model": "alibaba/blablabla",
"dim": 512,
"description": "...",
"license": "apache-2.0",
"size_in_GB": 1.23,
"sources": { ... } # optional
@hh-space-invader (Contributor) commented on Feb 12, 2025:

Shouldn't that be mandatory? We use sources to download the model/config.

Suggested change:
-"sources": { ... } # optional
+"sources": {
+    "hf": "alibaba/blablabla"
+}

}
mean_pooling: apply mean_pooling or not.
normalization: apply normalization or not.

Returns:
None
"""
if mean_pooling and not normalization:
PooledEmbedding.add_custom_model(model_info)
elif mean_pooling and normalization:
PooledNormalizedEmbedding.add_custom_model(model_info)
elif "clip" in model_info["model"].lower():
Contributor comment:

This is not always the case. The jinaai/jina-clip-v1 model is added to the text models (not CLIP), as it needs normalization.

Suggested change:
-elif "clip" in model_info["model"].lower():
+elif "clip" in model_info["model"].lower() and not normalization:

CLIPOnnxEmbedding.add_custom_model(model_info)
else:
OnnxTextEmbedding.add_custom_model(model_info)

def __init__(
self,
model_name: str = "BAAI/bge-small-en-v1.5",
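For context, a hedged usage sketch of the new API, mirroring the test added below; the model name and metadata are placeholders, not values taken from this PR:

from fastembed.text.text_embedding import TextEmbedding

# Register a custom model description; the keys mirror those used in the test
# below. "my-org/my-model" is a hypothetical repository name.
TextEmbedding.add_custom_model(
    model_info={
        "model": "my-org/my-model",
        "dim": 384,
        "description": "my custom model",
        "license": "mit",
        "size_in_GB": 0.13,
        "sources": {"hf": "my-org/my-model"},  # HF repo the weights are pulled from
        "model_file": "onnx/model.onnx",
        "additional_files": [],
    },
    mean_pooling=True,
    normalization=True,
)

# The registered name can now be used like any built-in model.
model = TextEmbedding(model_name="my-org/my-model")
embeddings = list(model.embed(["hello world", "flag embedding"]))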
115 changes: 115 additions & 0 deletions tests/test_add_custom_model.py
@@ -0,0 +1,115 @@
import os
import numpy as np
import pytest

from fastembed.text.text_embedding import TextEmbedding
from tests.utils import delete_model_cache

canonical_vectors = [
{
"model": "intfloat/multilingual-e5-small",
"mean_pooling": True,
"normalization": True,
"canonical_vector": [3.1317e-02, 3.0939e-02, -3.5117e-02, -6.7274e-02, 8.5084e-02],
},
{
"model": "intfloat/multilingual-e5-small",
"mean_pooling": True,
"normalization": False,
"canonical_vector": [1.4604e-01, 1.4428e-01, -1.6376e-01, -3.1372e-01, 3.9677e-01],
},
{
"model": "mixedbread-ai/mxbai-embed-xsmall-v1",
"mean_pooling": False,
"normalization": False,
"canonical_vector": [
2.49407589e-02,
1.00189969e-02,
1.07807154e-02,
3.63860987e-02,
-2.27128249e-02,
],
},
]

DIMENSIONS = {
"intfloat/multilingual-e5-small": 384,
"mixedbread-ai/mxbai-embed-xsmall-v1": 384,
}

SOURCES = {
"intfloat/multilingual-e5-small": "intfloat/multilingual-e5-small",
"mixedbread-ai/mxbai-embed-xsmall-v1": "mixedbread-ai/mxbai-embed-xsmall-v1",
}


@pytest.mark.parametrize("scenario", canonical_vectors)
def test_add_custom_model_variations(scenario):
is_ci = bool(os.getenv("CI", False))

base_model_name = scenario["model"]
mean_pooling = scenario["mean_pooling"]
normalization = scenario["normalization"]
cv = np.array(scenario["canonical_vector"], dtype=np.float32)

backup_supported_models = {}
for embedding_cls in TextEmbedding.EMBEDDINGS_REGISTRY:
backup_supported_models[embedding_cls] = embedding_cls.list_supported_models().copy()

suffixes = []
suffixes.append("mean" if mean_pooling else "no-mean")
suffixes.append("norm" if normalization else "no-norm")
suffix_str = "-".join(suffixes)

custom_model_name = f"{base_model_name}-{suffix_str}"

dim = DIMENSIONS[base_model_name]
hf_source = SOURCES[base_model_name]

model_info = {
"model": custom_model_name,
"dim": dim,
"description": f"{base_model_name} with {suffix_str}",
"license": "mit",
"size_in_GB": 0.13,
"sources": {
"hf": hf_source,
},
"model_file": "onnx/model.onnx",
"additional_files": [],
}

if is_ci and model_info["size_in_GB"] > 1.0:
pytest.skip(
f"Skipping {custom_model_name} on CI due to size_in_GB={model_info['size_in_GB']}"
)

try:
TextEmbedding.add_custom_model(
model_info=model_info, mean_pooling=mean_pooling, normalization=normalization
)

model = TextEmbedding(model_name=custom_model_name)

docs = ["hello world", "flag embedding"]
embeddings = list(model.embed(docs))
embeddings = np.stack(embeddings, axis=0)

assert embeddings.shape == (
2,
dim,
), f"Expected shape (2, {dim}) for {custom_model_name}, but got {embeddings.shape}"

num_compare_dims = cv.shape[0]
assert np.allclose(
embeddings[0, :num_compare_dims], cv, atol=1e-3
), f"Embedding mismatch for {custom_model_name} (first {num_compare_dims} dims)."

assert not np.allclose(embeddings[0, :], 0.0), "Embedding should not be all zeros."

if is_ci:
delete_model_cache(model.model._model_dir)

finally:
for embedding_cls, old_list in backup_supported_models.items():
embedding_cls.supported_models = old_list