Modifications to EmbeddingService
picaultj committed Dec 31, 2024
1 parent 0055648 commit 8c1c142
Showing 5 changed files with 167 additions and 132 deletions.
15 changes: 10 additions & 5 deletions bertrend/demos/demos_utils/embed_documents_component.py
@@ -19,7 +19,9 @@ def display_embed_documents_component() -> bool:
embedding_model_name = SessionStateManager.get("embedding_model_name")
if SessionStateManager.get("embedding_service_type", "local") == "local":
embedding_service = EmbeddingService(
local=True, embedding_model_name=embedding_model_name
local=True,
model_name=embedding_model_name,
embedding_dtype=embedding_dtype,
)
else:
embedding_service = EmbeddingService(
@@ -32,14 +34,17 @@ def display_embed_documents_component() -> bool:
TEXT_COLUMN
].tolist()

embedding_model, embeddings = embedding_service.embed_documents(
texts=texts,
embedding_model_name=embedding_model_name,
embedding_dtype=embedding_dtype,
embedding_model, embeddings, token_strings, token_embeddings = (
embedding_service.embed(
texts=texts,
)
)

SessionStateManager.set("embedding_model", embedding_model)
SessionStateManager.set("embeddings", embeddings)
SessionStateManager.set("token_strings", token_strings)
SessionStateManager.set("token_embeddings", token_embeddings)

SessionStateManager.set("data_embedded", True)

st.success(EMBEDDINGS_CALCULATED_MESSAGE, icon=SUCCESS_ICON)
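
For reference, a minimal sketch of how a caller drives the updated interface; the model name and dtype literals below are illustrative placeholders, not values from the repository's configuration (the demo reads them from session state):

from bertrend.services.embedding_service import EmbeddingService

# Hypothetical values for illustration only.
embedding_service = EmbeddingService(
    local=True,
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    embedding_dtype="float32",
)
embedding_model, embeddings, token_strings, token_embeddings = embedding_service.embed(
    texts=["first document", "second document"],
)
# embeddings: np.ndarray with one row per input text.
# token_strings / token_embeddings: per-document lists in local mode,
# None when the remote service is used.
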
1 change: 0 additions & 1 deletion bertrend/demos/weak_signals/app.py
@@ -41,7 +41,6 @@
display_bertopic_hyperparameters,
display_bertrend_hyperparameters,
)
from bertrend.services.embedding_service import EmbeddingService
from bertrend.topic_model import TopicModel
from bertrend.demos.weak_signals.messages import (
MODEL_MERGING_COMPLETE_MESSAGE,
194 changes: 153 additions & 41 deletions bertrend/services/embedding_service.py
@@ -3,11 +3,13 @@
# SPDX-License-Identifier: MPL-2.0
# This file is part of BERTrend.
import json
from typing import List, Tuple
from typing import List, Tuple, Union

import numpy as np
import pandas as pd
import requests
import torch
from bertopic.backend import BaseEmbedder
from loguru import logger
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
@@ -20,27 +22,30 @@
)


class EmbeddingService:
class EmbeddingService(BaseEmbedder):
def __init__(
self,
local: bool = True,
embedding_model_name: str = None,
model_name: str = None,
embedding_dtype: str = None,
host: str = EMBEDDING_CONFIG["host"],
port: str = EMBEDDING_CONFIG["port"],
):
super().__init__()
self.local = local
if not self.local:
self.url = f"http://{host}:{port}"
self.embedding_model_name = embedding_model_name

# TODO: harmonize interfaces for local / remote services
self.embedding_model = None
self.embedding_model_name = model_name
self.embedding_dtype = embedding_dtype

def embed_documents(
self,
texts: List[str],
embedding_model_name: str,
embedding_dtype: str,
) -> Tuple[str | SentenceTransformer, np.ndarray]:
def embed(self, texts: Union[List[str], pd.Series], verbose: bool = False) -> Tuple[
str | SentenceTransformer,
np.ndarray,
List[List[str]] | None,
List[np.ndarray] | None,
]:
"""
Embed a list of documents using a Sentence Transformer model.
@@ -49,42 +54,39 @@ def embed_documents(
memory efficiently, especially for large datasets.
Args:
texts (List[str]): A list of text documents to be embedded.
texts (Union[List[str], pd.Series]): A list or pandas Series of text documents to be embedded.
embedding_model_name (str): The name of the Sentence Transformer model to use.
embedding_dtype (str): The data type to use for the embeddings ('float32', 'float16', or 'bfloat16').
embedding_device (str, optional): The device to use for embedding ('cuda' or 'cpu').
Defaults to 'cuda' if available, else 'cpu'.
batch_size (int, optional): The number of texts to process in each batch. Defaults to 32.
batch_size (i nt, optional): The number of texts to process in each batch. Defaults to 32.
max_seq_length (int, optional): The maximum sequence length for the model. Defaults to 512.
Returns:
Tuple[str | SentenceTransformer, np.ndarray, List[List[str]] | None, List[np.ndarray] | None]: A tuple containing:
- The loaded and configured Sentence Transformer model (or the remote model name).
- A numpy array of embeddings, where each row corresponds to a text in the input list.
- The grouped token strings per document, or None when the remote service is used.
- The grouped token embeddings per document, or None when the remote service is used.
Raises:
ValueError: If an invalid embedding_dtype is provided.
"""
# Convert to list if input is a pandas Series
if isinstance(texts, pd.Series):
texts = texts.tolist()
if self.local:
return self._local_embed_documents(
texts,
embedding_model_name,
embedding_dtype,
)
return self._local_embed_documents(texts)
else:
return self._remote_embed_documents(
texts,
)
return self._remote_embed_documents(texts)

def _local_embed_documents(
self,
texts: List[str],
embedding_model_name: str,
embedding_dtype: str,
embedding_device: str = EMBEDDING_DEVICE,
batch_size: int = EMBEDDING_BATCH_SIZE,
max_seq_length: int = EMBEDDING_MAX_SEQ_LENGTH,
) -> Tuple[SentenceTransformer, np.ndarray]:
) -> Tuple[SentenceTransformer, np.ndarray, List[List[str]], List[np.ndarray]]:
"""
Embed a list of documents using a Sentence Transformer model.
@@ -111,23 +113,27 @@ def _local_embed_documents(
"""
# Configure model kwargs based on the specified dtype
model_kwargs = {}
if embedding_dtype == "float16":
if self.embedding_dtype == "float16":
model_kwargs["torch_dtype"] = torch.float16
elif embedding_dtype == "bfloat16":
elif self.embedding_dtype == "bfloat16":
model_kwargs["torch_dtype"] = torch.bfloat16
elif embedding_dtype != "float32":
elif self.embedding_dtype != "float32":
raise ValueError(
"Invalid embedding_dtype. Must be 'float32', 'float16', or 'bfloat16'."
)

# Load the embedding model
embedding_model = SentenceTransformer(
embedding_model_name,
device=embedding_device,
trust_remote_code=True,
model_kwargs=model_kwargs,
)
embedding_model.max_seq_length = max_seq_length
if self.embedding_model is None:
logger.info(f"Loading embedding model: {self.embedding_model_name}...")
# Load the embedding model
self.embedding_model = SentenceTransformer(
self.embedding_model_name,
device=embedding_device,
trust_remote_code=True,
model_kwargs=model_kwargs,
)
self.embedding_model.max_seq_length = max_seq_length
self.batch_size = batch_size
logger.debug("Embedding model loaded")

# Calculate the number of batches
num_batches = (len(texts) + batch_size - 1) // batch_size
@@ -140,19 +146,42 @@ def _local_embed_documents(
start_idx = i * batch_size
end_idx = min(start_idx + batch_size, len(texts))
batch_texts = texts[start_idx:end_idx]
batch_embeddings = embedding_model.encode(
batch_texts, show_progress_bar=False
batch_embeddings = self.embedding_model.encode(
batch_texts,
show_progress_bar=False,
output_value=None, # to get all output values, not only sentence embeddings
)
embeddings.append(batch_embeddings)

# Concatenate all batch embeddings
embeddings = np.concatenate(embeddings, axis=0)
all_embeddings = np.concatenate(embeddings, axis=0)
logger.success(f"Embedded {len(texts)} documents in {num_batches} batches")

embeddings = [
item["sentence_embedding"].detach().cpu() for item in all_embeddings
]
embeddings = torch.stack(embeddings)
embeddings = embeddings.numpy()

token_embeddings = [
item["token_embeddings"].detach().cpu() for item in all_embeddings
]
token_ids = [item["input_ids"].detach().cpu() for item in all_embeddings]

token_embeddings = convert_to_numpy(token_embeddings)
token_ids = convert_to_numpy(token_ids, type="token_id")

tokenizer = self.embedding_model._first_module().tokenizer

token_strings, token_embeddings = group_tokens(
tokenizer, token_ids, token_embeddings, language="french"
)

return embedding_model, embeddings
return self.embedding_model, embeddings, token_strings, token_embeddings

def _remote_embed_documents(
self, texts: List[str], show_progress_bar: bool = True
) -> Tuple[str, np.ndarray]:
) -> Tuple[str, np.ndarray, None, None]:
"""
Embed a list of documents using a Sentence Transformer model.
@@ -185,7 +214,7 @@ def _remote_embed_documents(
if response.status_code == 200:
embeddings = np.array(response.json()["embeddings"])
logger.debug(f"Computing embeddings done for batch")
return self._get_remote_model_name(), embeddings
return self._get_remote_model_name(), embeddings, None, None
else:
logger.error(f"Error: {response.status_code}")
raise Exception(f"Error: {response.status_code}")
@@ -204,3 +233,86 @@ def _get_remote_model_name(self) -> str:
else:
logger.error(f"Error: {response.status_code}")
raise Exception(f"Error: {response.status_code}")


def convert_to_numpy(obj, type=None):
"""
Convert a torch.Tensor or list of torch.Tensors to numpy arrays.
Args:
obj: The object to convert (torch.Tensor or list).
type: Optional conversion kind; "token_id" casts to int64, anything else casts to float32.
Returns:
np.ndarray or list of np.ndarray.
"""
if isinstance(obj, torch.Tensor):
return (
obj.numpy().astype(np.int64)
if type == "token_id"
else obj.numpy().astype(np.float32)
)
elif isinstance(obj, list):
return [convert_to_numpy(item, type=type) for item in obj]
else:
raise TypeError("Object must be a list or torch.Tensor")
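
A quick sketch of convert_to_numpy on both accepted input shapes; the tensors are made up, and the function is assumed in scope from this module:

import numpy as np
import torch

ids = [torch.tensor([101, 2023, 102]), torch.tensor([101, 102])]
as_ids = convert_to_numpy(ids, type="token_id")
assert all(a.dtype == np.int64 for a in as_ids)  # token ids cast to int64

vectors = convert_to_numpy([torch.rand(3, 8)])
assert vectors[0].dtype == np.float32  # embeddings cast to float32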


def group_tokens(tokenizer, token_ids, token_embeddings, language="french"):
"""
Group split tokens into whole words and average their embeddings.
Args:
tokenizer: The tokenizer to use for converting ids to tokens.
token_ids: List of token ids.
token_embeddings: List of token embeddings.
language: The language of the tokens (default is "french").
Returns:
List of grouped tokens and their corresponding embeddings.
"""
grouped_token_lists = []
grouped_embedding_lists = []

special_tokens = {
"english": ["[CLS]", "[SEP]", "[PAD]"],
"french": ["<s>", "</s>", "<pad>"],
}
subword_prefix = {"english": "##", "french": "▁"}

for token_id, token_embedding in tqdm(
zip(token_ids, token_embeddings), desc="Grouping split tokens into whole words"
):
tokens = tokenizer.convert_ids_to_tokens(token_id)

grouped_tokens = []
grouped_embeddings = []
current_word = ""
current_embedding = []

for token, embedding in zip(tokens, token_embedding):
if token in special_tokens[language]:
continue

if language == "french" and token.startswith(subword_prefix[language]):
if current_word:
grouped_tokens.append(current_word)
grouped_embeddings.append(np.mean(current_embedding, axis=0))
current_word = token[1:]
current_embedding = [embedding]
elif language == "english" and not token.startswith(
subword_prefix[language]
):
if current_word:
grouped_tokens.append(current_word)
grouped_embeddings.append(np.mean(current_embedding, axis=0))
current_word = token
current_embedding = [embedding]
else:
current_word += token.lstrip(subword_prefix[language])
current_embedding.append(embedding)

if current_word:
grouped_tokens.append(current_word)
grouped_embeddings.append(np.mean(current_embedding, axis=0))

grouped_token_lists.append(grouped_tokens)
grouped_embedding_lists.append(np.array(grouped_embeddings))

return grouped_token_lists, grouped_embedding_lists
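
To make the grouping rule concrete, a self-contained sketch with a stubbed tokenizer; the ids, pieces, and vectors are invented, and group_tokens is assumed importable from this module:

import numpy as np

class StubTokenizer:
    # Hypothetical stand-in; real tokenizers provide convert_ids_to_tokens.
    _vocab = {0: "<s>", 1: "▁embed", 2: "dings", 3: "▁rock", 4: "</s>"}

    def convert_ids_to_tokens(self, ids):
        return [self._vocab[int(i)] for i in ids]

token_ids = [np.array([0, 1, 2, 3, 4])]
token_embeddings = [np.ones((5, 4), dtype=np.float32)]

words, vectors = group_tokens(StubTokenizer(), token_ids, token_embeddings, language="french")
print(words)             # [['embeddings', 'rock']] -- subword pieces merged into words
print(vectors[0].shape)  # (2, 4) -- one mean-pooled vector per word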