diff --git a/lilac/embeddings/bge.py b/lilac/embeddings/bge.py new file mode 100644 index 00000000..0c578b1e --- /dev/null +++ b/lilac/embeddings/bge.py @@ -0,0 +1,94 @@ +"""Gegeral Text Embeddings (GTE) model. Open-source model, designed to run on device.""" +import gc +from typing import TYPE_CHECKING, ClassVar, Iterator, Optional + +from typing_extensions import override + +if TYPE_CHECKING: + from FlagEmbedding import BGEM3FlagModel + + +import functools + +from ..schema import Item +from ..signal import TextEmbeddingSignal +from ..splitters.spacy_splitter import clustering_spacy_chunker +from ..tasks import TaskExecutionType +from .embedding import chunked_compute_embedding +from .transformer_utils import SENTENCE_TRANSFORMER_BATCH_SIZE, setup_model_device + +# See https://huggingface.co/spaces/mteb/leaderboard for leaderboard of models. +BGE_M3 = 'BAAI/bge-m3' + + +@functools.cache +def _get_and_cache_bge_m3(model_name: str) -> 'BGEM3FlagModel': + try: + from FlagEmbedding import BGEM3FlagModel + except ImportError: + raise ImportError( + 'Could not import the "FlagEmbedding" python package. ' + 'Please install it with `pip install "lilac[bge]".' + ) + model = BGEM3FlagModel( + 'BAAI/bge-m3', use_fp16=True + ) # Setting use_fp16 to True speeds up computation with a slight performance degradation + return model + return setup_model_device(model, model_name) + + +class BGEM3(TextEmbeddingSignal): + """Computes BGE-M3 embeddings. + +
This embedding runs on-device. See the [model card](https://huggingface.co/BAAI/bge-m3) + for details. + """ + + name: ClassVar[str] = 'bge-m3' + display_name: ClassVar[str] = 'BGE-M3' + local_batch_size: ClassVar[int] = SENTENCE_TRANSFORMER_BATCH_SIZE + local_parallelism: ClassVar[int] = 1 + local_strategy: ClassVar[TaskExecutionType] = 'threads' + supports_garden: ClassVar[bool] = False + + _model_name = BGE_M3 + _model: 'BGEM3FlagModel' + + @override + def setup(self) -> None: + self._model = _get_and_cache_bge_m3(self._model_name) + + @override + def compute(self, docs: list[str]) -> list[Optional[Item]]: + """Call the embedding function.""" + + def _encode(doc): + # Extract the dense vectors from the model. + return self._model.encode(doc)['dense_vecs'] + + # While we get docs in batches of 1024, the chunker expands that by a factor of 3-10. + # The sentence transformer API actually does batching internally, so we pass + # local_batch_size * 16 to allow the library to see all the chunks at once. + return chunked_compute_embedding( + _encode, docs, self.local_batch_size * 16, chunker=clustering_spacy_chunker + ) + + @override + def compute_garden(self, docs: Iterator[str]) -> Iterator[Item]: + raise NotImplementedError('Garden computation is not supported for BGE-M3.') + + @override + def teardown(self) -> None: + if not hasattr(self, '_model'): + return + + self._model.cpu() + del self._model + gc.collect() + + try: + import torch + + torch.cuda.empty_cache() + except ImportError: + pass diff --git a/lilac/signals/default_signals.py b/lilac/signals/default_signals.py index b89988ee..451788ec 100644 --- a/lilac/signals/default_signals.py +++ b/lilac/signals/default_signals.py @@ -1,5 +1,6 @@ """Registers all available default signals.""" +from ..embeddings.bge import BGEM3 from ..embeddings.cohere import Cohere from ..embeddings.gte import GTEBase, GTESmall, GTETiny from ..embeddings.jina import JinaV2Base, JinaV2Small @@ -43,3 +44,5 @@ def register_default_signals() -> None: register_signal(JinaV2Small, exists_ok=True) register_signal(JinaV2Base, exists_ok=True) + + register_signal(BGEM3, exists_ok=True) diff --git a/notebooks/Test.ipynb b/notebooks/Test.ipynb new file mode 100644 index 00000000..e69de29b diff --git a/poetry.lock b/poetry.lock index 0e67239c..f5717225 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,35 @@ # This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +[[package]] +name = "accelerate" +version = "0.27.2" +description = "Accelerate" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "accelerate-0.27.2-py3-none-any.whl", hash = "sha256:a818dd27b9ba24e9eb5030d1b285cf4cdd1b41bbfa675fb4eb2477ddfc097074"}, + {file = "accelerate-0.27.2.tar.gz", hash = "sha256:cc715fe9a8bc7a286259bfb6d65fb78363badd3371e7cbda4e4a4ef34a0010aa"}, +] + +[package.dependencies] +huggingface-hub = "*" +numpy = ">=1.17" +packaging = ">=20.0" +psutil = "*" +pyyaml = "*" +safetensors = ">=0.3.1" +torch = ">=1.10.0" + +[package.extras] +dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "deepspeed (<0.13.0)", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.1.15,<0.2.0)", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] +quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.1.15,<0.2.0)"] +rich = ["rich"] +sagemaker = ["sagemaker"] +test-dev = ["bitsandbytes", "datasets", "deepspeed (<0.13.0)", "evaluate", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] +test-prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"] +test-trackers = ["comet-ml", "dvclive", "tensorboard", "wandb"] +testing = ["bitsandbytes", "datasets", "deepspeed (<0.13.0)", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] + [[package]] name = "aiohttp" version = "3.9.1" @@ -1527,6 +1557,23 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1 testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] typing = ["typing-extensions (>=4.8)"] +[[package]] +name = "flagembedding" +version = "1.2.3" +description = "FlagEmbedding" +optional = true +python-versions = "*" +files = [ + {file = "FlagEmbedding-1.2.3.tar.gz", hash = "sha256:f16ab1fa59969dd28e4729f0157a5225ba21a429b3adfa5cddf34c2ad0d0c3be"}, +] + +[package.dependencies] +accelerate = ">=0.20.1" +datasets = "*" +sentence_transformers = "*" +torch = ">=1.6.0" +transformers = ">=4.33.0" + [[package]] name = "floret" version = "0.10.5" @@ -4113,6 +4160,7 @@ files = [ {file = "numpy-1.26.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3c67423b3703f8fbd90f5adaa37f85b5794d3366948efe9a5190a5f3a83fc34e"}, {file = "numpy-1.26.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46f47ee566d98849323f01b349d58f2557f02167ee301e5e28809a8c0e27a2d0"}, {file = "numpy-1.26.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a8474703bffc65ca15853d5fd4d06b18138ae90c17c8d12169968e998e448bb5"}, + {file = "numpy-1.26.3.tar.gz", hash = "sha256:697df43e2b6310ecc9d95f05d5ef20eacc09c7c4ecc9da3f235d39e71b7da1e4"}, ] [[package]] @@ -7054,52 +7102,52 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0, [[package]] name = "transformers" -version = "4.35.2" +version = "4.37.2" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = true python-versions = ">=3.8.0" files = [ - {file = "transformers-4.35.2-py3-none-any.whl", hash = "sha256:9dfa76f8692379544ead84d98f537be01cd1070de75c74efb13abcbc938fbe2f"}, - {file = "transformers-4.35.2.tar.gz", hash = "sha256:2d125e197d77b0cdb6c9201df9fa7e2101493272e448b9fba9341c695bee2f52"}, + {file = "transformers-4.37.2-py3-none-any.whl", hash = "sha256:595a8b12a1fcc4ad0ced49ce206c58e17be68c85d7aee3d7546d04a32c910d2e"}, + {file = "transformers-4.37.2.tar.gz", hash = "sha256:f307082ae5d528b8480611a4879a4a11651012d0e9aaea3f6cf17219ffd95542"}, ] [package.dependencies] filelock = "*" -huggingface-hub = ">=0.16.4,<1.0" +huggingface-hub = ">=0.19.3,<1.0" numpy = ">=1.17" packaging = ">=20.0" pyyaml = ">=5.1" regex = "!=2019.12.17" requests = "*" -safetensors = ">=0.3.1" +safetensors = ">=0.4.1" tokenizers = ">=0.14,<0.19" tqdm = ">=4.27" [package.extras] -accelerate = ["accelerate (>=0.20.3)"] -agents = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.10,!=1.12.0)"] -all = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +accelerate = ["accelerate (>=0.21.0)"] +agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.11,!=1.12.0)"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.11,!=1.12.0)", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] codecarbon = ["codecarbon (==1.2.0)"] -deepspeed = ["accelerate (>=0.20.3)", "deepspeed (>=0.9.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -docs = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.11,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.11,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.11,!=1.12.0)", "torchaudio", "torchvision"] docs-specific = ["hf-doc-builder"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] ftfy = ["ftfy"] -integrations = ["optuna", "ray[tune]", "sigopt"] +integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] modelcreation = ["cookiecutter (==1.7.3)"] -natten = ["natten (>=0.14.6)"] +natten = ["natten (>=0.14.6,<0.15.0)"] onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] optuna = ["optuna"] -quality = ["GitPython (<3.1.19)", "black (>=23.1,<24.0)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (>=0.0.241,<=0.0.259)", "urllib3 (<2.0.0)"] -ray = ["ray[tune]"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"] +ray = ["ray[tune] (>=2.7.0)"] retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] sagemaker = ["sagemaker (>=2.31.0)"] sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] @@ -7107,18 +7155,18 @@ serving = ["fastapi", "pydantic (<2)", "starlette", "uvicorn"] sigopt = ["sigopt"] sklearn = ["scikit-learn"] speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"] -tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx"] -tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] timm = ["timm"] tokenizers = ["tokenizers (>=0.14,<0.19)"] -torch = ["accelerate (>=0.20.3)", "torch (>=1.10,!=1.12.0)"] +torch = ["accelerate (>=0.21.0)", "torch (>=1.11,!=1.12.0)"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -torch-vision = ["Pillow (<10.0.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.16.4,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "tqdm (>=4.27)"] +torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch (>=1.11,!=1.12.0)", "tqdm (>=4.27)"] video = ["av (==9.2.0)", "decord (==0.6.0)"] -vision = ["Pillow (<10.0.0)"] +vision = ["Pillow (>=10.0.1,<=15.0)"] [[package]] name = "triton" @@ -8104,6 +8152,7 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] all = ["cohere", "detect-secrets", "email-reply-parser", "google-api-python-client", "google-auth-httplib2", "google-auth-oauthlib", "hdbscan", "langdetect", "langsmith", "llama-hub", "llama-index", "openai", "presidio_analyzer", "sentence-transformers", "spacy", "textacy", "umap-learn"] +bge = ["FlagEmbedding", "transformers"] cohere = ["cohere"] embeddings = ["cohere", "openai", "sentence-transformers"] github = ["llama-hub", "llama-index"] @@ -8122,4 +8171,4 @@ text-stats = ["textacy"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "3bba3fdb0ea5d7e2efdbaa63bae5d52245f761a8e65b13d55d679e5c0c68a4b2" +content-hash = "94b8c942172e5a02cb89c0fe8f7ea134169bdf5356d5486549bd8facb73ba8aa" diff --git a/pyproject.toml b/pyproject.toml index 991c91e9..60171251 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,8 @@ jinja2 = "^3.1.3" # Used for directory li cohere = { version = "^4.32", optional = true } openai = { version = "^1.7.1", optional = true } sentence-transformers = { version = "^2.2.2", optional = true } # SBERT on-device embeddings. +FlagEmbedding = { version = "^1.2.3", optional = true } # bge on-device embeddings. +transformers = { version = "^4.37.2", optional = true } # bge on-device embeddings. # Gmail source. email-reply-parser = { version = "^0.5.12", optional = true } @@ -86,6 +88,7 @@ hdbscan = { version = "^0.8.33", optional = true } [tool.poetry.extras] all = [ + "bge", "cohere", "detect-secrets", "email-reply-parser", @@ -131,6 +134,7 @@ text_stats = ["textacy"] # Text statistics. # Individual embeddings. gte = ["sentence-transformers"] +bge = ["FlagEmbedding", "transformers"] sbert = ["sentence-transformers"] cohere = ["cohere"] openai = ["openai"] diff --git a/web/lib/fastapi_client/models/ConceptSignal.ts b/web/lib/fastapi_client/models/ConceptSignal.ts index 4196e1b1..e67095f9 100644 --- a/web/lib/fastapi_client/models/ConceptSignal.ts +++ b/web/lib/fastapi_client/models/ConceptSignal.ts @@ -11,7 +11,7 @@ export type ConceptSignal = { /** * The name of the pre-computed embedding. */ - embedding: 'cohere' | 'sbert' | 'openai' | 'gte-tiny' | 'gte-small' | 'gte-base' | 'jina-v2-small' | 'jina-v2-base'; + embedding: 'cohere' | 'sbert' | 'openai' | 'gte-tiny' | 'gte-small' | 'gte-base' | 'jina-v2-small' | 'jina-v2-base' | 'bge-m3'; namespace: string; concept_name: string; version?: (number | null); diff --git a/web/lib/fastapi_client/models/SemanticSimilaritySignal.ts b/web/lib/fastapi_client/models/SemanticSimilaritySignal.ts index e159b80f..7adb8843 100644 --- a/web/lib/fastapi_client/models/SemanticSimilaritySignal.ts +++ b/web/lib/fastapi_client/models/SemanticSimilaritySignal.ts @@ -14,7 +14,7 @@ export type SemanticSimilaritySignal = { /** * The name of the pre-computed embedding. */ - embedding: 'cohere' | 'sbert' | 'openai' | 'gte-tiny' | 'gte-small' | 'gte-base' | 'jina-v2-small' | 'jina-v2-base'; + embedding: 'cohere' | 'sbert' | 'openai' | 'gte-tiny' | 'gte-small' | 'gte-base' | 'jina-v2-small' | 'jina-v2-base' | 'bge-m3'; query: string; /** * The input type of the query, used for the query embedding.