Skip to content

Commit

Permalink
Add bge-m3 embeddings.
Browse files Browse the repository at this point in the history
  • Loading branch information
nsthorat committed Feb 14, 2024
1 parent e41abe0 commit 20b80b1
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 28 deletions.
94 changes: 94 additions & 0 deletions lilac/embeddings/bge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""BGE-M3 text embedding model. Open-source model, designed to run on device."""
import gc
from typing import TYPE_CHECKING, ClassVar, Iterator, Optional

from typing_extensions import override

if TYPE_CHECKING:
from FlagEmbedding import BGEM3FlagModel


import functools

from ..schema import Item
from ..signal import TextEmbeddingSignal
from ..splitters.spacy_splitter import clustering_spacy_chunker
from ..tasks import TaskExecutionType
from .embedding import chunked_compute_embedding
from .transformer_utils import SENTENCE_TRANSFORMER_BATCH_SIZE, setup_model_device

# See https://huggingface.co/spaces/mteb/leaderboard for leaderboard of models.
# Model identifier for BGE-M3 (model card: https://huggingface.co/BAAI/bge-m3).
BGE_M3 = 'BAAI/bge-m3'


@functools.cache
def _get_and_cache_bge_m3(model_name: str) -> 'BGEM3FlagModel':
    """Load the BGE-M3 model identified by `model_name`, memoized per name.

    Args:
        model_name: Model identifier passed to `BGEM3FlagModel`.

    Returns:
        The constructed `BGEM3FlagModel`. Because of `functools.cache`, repeated calls
        with the same `model_name` return the same instance.

    Raises:
        ImportError: If the optional "FlagEmbedding" package cannot be imported.
    """
    try:
        # Imported lazily so the package stays optional (the `bge` extra).
        from FlagEmbedding import BGEM3FlagModel
    except ImportError as e:
        # Chain the original error so the underlying import failure stays visible.
        raise ImportError(
            'Could not import the "FlagEmbedding" python package. '
            'Please install it with `pip install "lilac[bge]"`.'
        ) from e

    # Use the requested name rather than a hard-coded literal, so the cache key and the
    # loaded model always agree.
    # Setting use_fp16 to True speeds up computation with a slight performance degradation.
    return BGEM3FlagModel(model_name, use_fp16=True)


class BGEM3(TextEmbeddingSignal):
    """Computes BGE-M3 embeddings.
    <br>This embedding runs on-device. See the [model card](https://huggingface.co/BAAI/bge-m3)
    for details.
    """

    name: ClassVar[str] = 'bge-m3'
    display_name: ClassVar[str] = 'BGE-M3'
    local_batch_size: ClassVar[int] = SENTENCE_TRANSFORMER_BATCH_SIZE
    local_parallelism: ClassVar[int] = 1
    local_strategy: ClassVar[TaskExecutionType] = 'threads'
    supports_garden: ClassVar[bool] = False

    # Identifier of the model to load in `setup`.
    _model_name = BGE_M3
    # Set by `setup` and removed again by `teardown`.
    _model: 'BGEM3FlagModel'

    @override
    def setup(self) -> None:
        """Load the model (shared through the module-level cache) onto this signal."""
        self._model = _get_and_cache_bge_m3(self._model_name)

    @override
    def compute(self, docs: list[str]) -> list[Optional[Item]]:
        """Call the embedding function."""

        def _encode(doc):
            # Extract the dense vectors from the model.
            return self._model.encode(doc)['dense_vecs']

        # While we get docs in batches of 1024, the chunker expands that by a factor of 3-10.
        # We pass local_batch_size * 16 so the encoder is handed many chunks per call.
        # NOTE(review): presumably `encode` batches internally, as the sentence-transformer
        # based embeddings do -- confirm against the FlagEmbedding API.
        return chunked_compute_embedding(
            _encode, docs, self.local_batch_size * 16, chunker=clustering_spacy_chunker
        )

    @override
    def compute_garden(self, docs: Iterator[str]) -> Iterator[Item]:
        """Garden is unsupported for this embedding; always raises NotImplementedError."""
        raise NotImplementedError('Garden computation is not supported for BGE-M3.')

    @override
    def teardown(self) -> None:
        """Release the model reference held by this signal and free cached GPU memory.

        Does nothing when `setup` was never called.
        """
        if not hasattr(self, '_model'):
            return

        # Move the weights off the accelerator before dropping the reference.
        # NOTE(review): `BGEM3FlagModel` appears to be a plain wrapper that keeps the torch
        # module in `.model` and has no `.cpu()` of its own -- confirm. Prefer the object's
        # own `.cpu()` when it exists (the original behavior), otherwise fall back to the
        # wrapped module, instead of failing with AttributeError during teardown.
        for candidate in (self._model, getattr(self._model, 'model', None)):
            to_cpu = getattr(candidate, 'cpu', None)
            if callable(to_cpu):
                to_cpu()
                break

        # NOTE: `_get_and_cache_bge_m3` is wrapped in `functools.cache`, which keeps its own
        # reference to the model, so this only drops the reference held by the signal.
        del self._model
        gc.collect()

        try:
            import torch

            torch.cuda.empty_cache()
        except ImportError:
            # torch is optional here; without it there is no CUDA cache to clear.
            pass
3 changes: 3 additions & 0 deletions lilac/signals/default_signals.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Registers all available default signals."""

from ..embeddings.bge import BGEM3
from ..embeddings.cohere import Cohere
from ..embeddings.gte import GTEBase, GTESmall, GTETiny
from ..embeddings.jina import JinaV2Base, JinaV2Small
Expand Down Expand Up @@ -43,3 +44,5 @@ def register_default_signals() -> None:

register_signal(JinaV2Small, exists_ok=True)
register_signal(JinaV2Base, exists_ok=True)

register_signal(BGEM3, exists_ok=True)
Empty file added notebooks/Test.ipynb
Empty file.
101 changes: 75 additions & 26 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ jinja2 = "^3.1.3" # Used for directory li
cohere = { version = "^4.32", optional = true }
openai = { version = "^1.7.1", optional = true }
sentence-transformers = { version = "^2.2.2", optional = true } # SBERT on-device embeddings.
FlagEmbedding = { version = "^1.2.3", optional = true } # bge on-device embeddings.
transformers = { version = "^4.37.2", optional = true } # bge on-device embeddings.

# Gmail source.
email-reply-parser = { version = "^0.5.12", optional = true }
Expand Down Expand Up @@ -86,6 +88,7 @@ hdbscan = { version = "^0.8.33", optional = true }

[tool.poetry.extras]
all = [
"bge",
"cohere",
"detect-secrets",
"email-reply-parser",
Expand Down Expand Up @@ -131,6 +134,7 @@ text_stats = ["textacy"] # Text statistics.

# Individual embeddings.
gte = ["sentence-transformers"]
bge = ["FlagEmbedding", "transformers"]
sbert = ["sentence-transformers"]
cohere = ["cohere"]
openai = ["openai"]
Expand Down
2 changes: 1 addition & 1 deletion web/lib/fastapi_client/models/ConceptSignal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export type ConceptSignal = {
/**
* The name of the pre-computed embedding.
*/
embedding: 'cohere' | 'sbert' | 'openai' | 'gte-tiny' | 'gte-small' | 'gte-base' | 'jina-v2-small' | 'jina-v2-base';
embedding: 'cohere' | 'sbert' | 'openai' | 'gte-tiny' | 'gte-small' | 'gte-base' | 'jina-v2-small' | 'jina-v2-base' | 'bge-m3';
namespace: string;
concept_name: string;
version?: (number | null);
Expand Down
2 changes: 1 addition & 1 deletion web/lib/fastapi_client/models/SemanticSimilaritySignal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export type SemanticSimilaritySignal = {
/**
* The name of the pre-computed embedding.
*/
embedding: 'cohere' | 'sbert' | 'openai' | 'gte-tiny' | 'gte-small' | 'gte-base' | 'jina-v2-small' | 'jina-v2-base';
embedding: 'cohere' | 'sbert' | 'openai' | 'gte-tiny' | 'gte-small' | 'gte-base' | 'jina-v2-small' | 'jina-v2-base' | 'bge-m3';
query: string;
/**
* The input type of the query, used for the query embedding.
Expand Down

0 comments on commit 20b80b1

Please sign in to comment.