Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Per-token embeddings #423

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions fastembed/late_interaction/token_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from typing import Union, Iterable, Optional, List, Dict, Any, Type

import numpy as np

from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.late_interaction.late_interaction_embedding_base import LateInteractionTextEmbeddingBase
from fastembed.text.onnx_embedding import OnnxTextEmbedding
from fastembed.text.onnx_text_model import TextEmbeddingWorker

supported_token_embeddings_models = [
{
"model": "jinaai/jina-embeddings-v2-small-en-tokens",
"dim": 512,
"description": "Text embeddings, Unimodal (text), English, 8192 input tokens truncation,"
" Prefixes for queries/documents: not necessary, 2023 year.",
"license": "apache-2.0",
"size_in_GB": 0.12,
"sources": {"hf": "xenova/jina-embeddings-v2-small-en"},
"model_file": "onnx/model.onnx",
},
]


class TokenEmbeddingsModel(OnnxTextEmbedding, LateInteractionTextEmbeddingBase):
@classmethod
def list_supported_models(cls) -> List[Dict[str, Any]]:
"""Lists the supported models.

Returns:
List[Dict[str, Any]]: A list of dictionaries containing the model information.
"""
return supported_token_embeddings_models

@classmethod
def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
return TokensEmbeddingWorker

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
# Size: (batch_size, sequence_length, hidden_size)
embeddings = output.model_output
# Size: (batch_size, sequence_length)
masks = output.attention_mask

# For each document we only select those embeddings that are not masked out

for i in range(embeddings.shape[0]):
yield embeddings[i, masks[i] == 1]

def embed(
self,
documents: Union[str, Iterable[str]],
batch_size: int = 256,
parallel: Optional[int] = None,
**kwargs,
) -> Iterable[np.ndarray]:
yield from OnnxTextEmbedding.embed(self, documents, batch_size=batch_size, parallel=parallel, **kwargs)

def tokenize_docs(self, documents: List[str]) -> List[np.ndarray]:
encoded = self.tokenizer.encode_batch(documents)
return [e.ids for e in encoded]


class TokensEmbeddingWorker(TextEmbeddingWorker):
def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> TokenEmbeddingsModel:
return TokenEmbeddingsModel(
model_name=model_name,
cache_dir=cache_dir,
threads=1,
**kwargs,
)


if __name__ == "__main__":
# Example usage
model = TokenEmbeddingsModel(model_name="jinaai/jina-embeddings-v2-small-en-tokens")
docs = ["Hello, world!", "hello", "hello hello"]

embeddings = model.embed(docs)
for emb in embeddings:
print(emb.shape)

print(model.tokenize_docs(docs))
Loading