-
Notifications
You must be signed in to change notification settings - Fork 121
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add custom model #463
base: main
Are you sure you want to change the base?
Add custom model #463
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -50,6 +50,38 @@ def list_supported_models(cls) -> list[dict[str, Any]]: | |||||||||
result.extend(embedding.list_supported_models()) | ||||||||||
return result | ||||||||||
|
||||||||||
@classmethod
def add_custom_model(
    cls, model_info: dict[str, Any], mean_pooling: bool = True, normalization: bool = False
) -> None:
    """
    Register a custom model so that TextEmbedding(...) can find it later.

    Args:
        model_info: Dictionary describing the model, e.g.:
            {
                "model": "alibaba/blablabla",
                "dim": 512,
                "description": "...",
                "license": "apache-2.0",
                "size_in_GB": 1.23,
                "sources": { ... },  # required: used to download the model/config
            }
        mean_pooling: apply mean pooling to the token embeddings or not.
        normalization: apply normalization to the (pooled) embeddings or not.

    Raises:
        ValueError: if a required key ("model" or "sources") is missing
            or empty in ``model_info``.

    Returns:
        None
    """
    # "sources" must be present: it is what the download machinery uses to
    # fetch the model/config, so fail fast here instead of erroring later
    # during model retrieval (reviewer-requested validation).
    missing = [key for key in ("model", "sources") if not model_info.get(key)]
    if missing:
        raise ValueError(f"model_info is missing required key(s): {', '.join(missing)}")

    # Route the registration to the embedding class whose post-processing
    # matches the requested pooling/normalization combination.
    if mean_pooling and not normalization:
        PooledEmbedding.add_custom_model(model_info)
    elif mean_pooling and normalization:
        PooledNormalizedEmbedding.add_custom_model(model_info)
    elif "clip" in model_info["model"].lower():
        # NOTE(review): a substring match on the model name is not a reliable
        # CLIP detector (flagged in review) — confirm, or replace with an
        # explicit flag/metadata field in model_info.
        CLIPOnnxEmbedding.add_custom_model(model_info)
    else:
        OnnxTextEmbedding.add_custom_model(model_info)
|
||||||||||
def __init__( | ||||||||||
self, | ||||||||||
model_name: str = "BAAI/bge-small-en-v1.5", | ||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import os | ||
import numpy as np | ||
import pytest | ||
|
||
from fastembed.text.text_embedding import TextEmbedding | ||
from tests.utils import delete_model_cache | ||
|
||
def _scenario(
    model: str, mean_pooling: bool, normalization: bool, canonical_vector: list[float]
) -> dict:
    """Build one parametrization entry for the custom-model tests."""
    return {
        "model": model,
        "mean_pooling": mean_pooling,
        "normalization": normalization,
        "canonical_vector": canonical_vector,
    }


# Known-good prefixes of the embedding vectors, one entry per
# pooling/normalization combination under test.
canonical_vectors = [
    _scenario(
        "intfloat/multilingual-e5-small",
        True,
        True,
        [3.1317e-02, 3.0939e-02, -3.5117e-02, -6.7274e-02, 8.5084e-02],
    ),
    _scenario(
        "intfloat/multilingual-e5-small",
        True,
        False,
        [1.4604e-01, 1.4428e-01, -1.6376e-01, -3.1372e-01, 3.9677e-01],
    ),
    _scenario(
        "mixedbread-ai/mxbai-embed-xsmall-v1",
        False,
        False,
        [
            2.49407589e-02,
            1.00189969e-02,
            1.07807154e-02,
            3.63860987e-02,
            -2.27128249e-02,
        ],
    ),
]

# Output dimensionality of each base model.
DIMENSIONS = {
    "intfloat/multilingual-e5-small": 384,
    "mixedbread-ai/mxbai-embed-xsmall-v1": 384,
}

# HuggingFace source repo for each base model (identical to the model name).
SOURCES = {name: name for name in DIMENSIONS}
|
||
|
||
@pytest.mark.parametrize("scenario", canonical_vectors)
def test_add_custom_model_variations(scenario):
    """Register a custom model variant and verify its embeddings match canonical values."""
    running_on_ci = bool(os.getenv("CI", False))

    base_model_name = scenario["model"]
    mean_pooling = scenario["mean_pooling"]
    normalization = scenario["normalization"]
    expected_prefix = np.array(scenario["canonical_vector"], dtype=np.float32)

    # Snapshot every registry's supported-model list so the custom
    # registration can be rolled back in the finally block.
    saved_registries = {
        embedding_cls: embedding_cls.list_supported_models().copy()
        for embedding_cls in TextEmbedding.EMBEDDINGS_REGISTRY
    }

    suffix_str = "-".join(
        [
            "mean" if mean_pooling else "no-mean",
            "norm" if normalization else "no-norm",
        ]
    )
    custom_model_name = f"{base_model_name}-{suffix_str}"

    dim = DIMENSIONS[base_model_name]
    model_info = {
        "model": custom_model_name,
        "dim": dim,
        "description": f"{base_model_name} with {suffix_str}",
        "license": "mit",
        "size_in_GB": 0.13,
        "sources": {
            "hf": SOURCES[base_model_name],
        },
        "model_file": "onnx/model.onnx",
        "additional_files": [],
    }

    if running_on_ci and model_info["size_in_GB"] > 1.0:
        pytest.skip(
            f"Skipping {custom_model_name} on CI due to size_in_GB={model_info['size_in_GB']}"
        )

    try:
        TextEmbedding.add_custom_model(
            model_info=model_info, mean_pooling=mean_pooling, normalization=normalization
        )

        model = TextEmbedding(model_name=custom_model_name)

        embeddings = np.stack(list(model.embed(["hello world", "flag embedding"])), axis=0)

        assert embeddings.shape == (
            2,
            dim,
        ), f"Expected shape (2, {dim}) for {custom_model_name}, but got {embeddings.shape}"

        # Only the leading dims have canonical values; compare that prefix.
        num_compare_dims = expected_prefix.shape[0]
        assert np.allclose(
            embeddings[0, :num_compare_dims], expected_prefix, atol=1e-3
        ), f"Embedding mismatch for {custom_model_name} (first {num_compare_dims} dims)."

        assert not np.allclose(embeddings[0, :], 0.0), "Embedding should not be all zeros."

        if running_on_ci:
            delete_model_cache(model.model._model_dir)

    finally:
        # Restore the registries regardless of the test outcome.
        for embedding_cls, saved_list in saved_registries.items():
            embedding_cls.supported_models = saved_list
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no need to redefine this method in the subclasses if they fetch the right `cls` variable.