diff --git a/bertrend/demos/demos_utils/embed_documents_component.py b/bertrend/demos/demos_utils/embed_documents_component.py
index 0d7cf1b..b92e336 100644
--- a/bertrend/demos/demos_utils/embed_documents_component.py
+++ b/bertrend/demos/demos_utils/embed_documents_component.py
@@ -19,7 +19,9 @@ def display_embed_documents_component() -> bool:
embedding_model_name = SessionStateManager.get("embedding_model_name")
if SessionStateManager.get("embedding_service_type", "local") == "local":
embedding_service = EmbeddingService(
- local=True, embedding_model_name=embedding_model_name
+ local=True,
+ model_name=embedding_model_name,
+ embedding_dtype=embedding_dtype,
)
else:
embedding_service = EmbeddingService(
@@ -32,14 +34,17 @@ def display_embed_documents_component() -> bool:
TEXT_COLUMN
].tolist()
- embedding_model, embeddings = embedding_service.embed_documents(
- texts=texts,
- embedding_model_name=embedding_model_name,
- embedding_dtype=embedding_dtype,
+ embedding_model, embeddings, token_strings, token_embeddings = (
+ embedding_service.embed(
+ texts=texts,
+ )
)
SessionStateManager.set("embedding_model", embedding_model)
SessionStateManager.set("embeddings", embeddings)
+ SessionStateManager.set("token_strings", token_strings)
+ SessionStateManager.set("token_embeddings", token_embeddings)
+
SessionStateManager.set("data_embedded", True)
st.success(EMBEDDINGS_CALCULATED_MESSAGE, icon=SUCCESS_ICON)
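
Reviewer note: a minimal sketch of the call pattern this hunk adopts, assuming a local sentence-transformers checkpoint; the model name below is an illustrative assumption, not something this PR prescribes.

```python
# Illustrative sketch of the reworked EmbeddingService API.
from bertrend.services.embedding_service import EmbeddingService

service = EmbeddingService(
    local=True,
    model_name="paraphrase-multilingual-MiniLM-L12-v2",  # hypothetical checkpoint
    embedding_dtype="float32",
)

# embed() now returns a 4-tuple; the token-level outputs are None when the
# remote embedding service is used instead of a local model.
model, embeddings, token_strings, token_embeddings = service.embed(
    texts=["Premier document.", "Deuxième document."]
)
print(embeddings.shape)  # (2, <embedding_dim>)
```
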
diff --git a/bertrend/demos/weak_signals/app.py b/bertrend/demos/weak_signals/app.py
index a0968a0..0b6d52e 100644
--- a/bertrend/demos/weak_signals/app.py
+++ b/bertrend/demos/weak_signals/app.py
@@ -41,7 +41,6 @@
display_bertopic_hyperparameters,
display_bertrend_hyperparameters,
)
-from bertrend.services.embedding_service import EmbeddingService
from bertrend.topic_model import TopicModel
from bertrend.demos.weak_signals.messages import (
MODEL_MERGING_COMPLETE_MESSAGE,
diff --git a/bertrend/services/embedding_service.py b/bertrend/services/embedding_service.py
index 612037e..343f1a2 100644
--- a/bertrend/services/embedding_service.py
+++ b/bertrend/services/embedding_service.py
@@ -3,11 +3,13 @@
# SPDX-License-Identifier: MPL-2.0
# This file is part of BERTrend.
import json
-from typing import List, Tuple
+from typing import List, Tuple, Union
import numpy as np
+import pandas as pd
import requests
import torch
+from bertopic.backend import BaseEmbedder
from loguru import logger
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
@@ -20,27 +22,30 @@
)
-class EmbeddingService:
+class EmbeddingService(BaseEmbedder):
def __init__(
self,
local: bool = True,
- embedding_model_name: str = None,
+ model_name: str = None,
+        embedding_dtype: str = "float32",
host: str = EMBEDDING_CONFIG["host"],
port: str = EMBEDDING_CONFIG["port"],
):
+ super().__init__()
self.local = local
if not self.local:
self.url = f"http://{host}:{port}"
- self.embedding_model_name = embedding_model_name
- # TODO: harmonize interfaces for local / remote services
+ self.embedding_model = None
+ self.embedding_model_name = model_name
+ self.embedding_dtype = embedding_dtype
- def embed_documents(
- self,
- texts: List[str],
- embedding_model_name: str,
- embedding_dtype: str,
- ) -> Tuple[str | SentenceTransformer, np.ndarray]:
+ def embed(self, texts: Union[List[str], pd.Series], verbose: bool = False) -> Tuple[
+ str | SentenceTransformer,
+ np.ndarray,
+ List[List[str]] | None,
+ List[np.ndarray] | None,
+ ]:
"""
Embed a list of documents using a Sentence Transformer model.
@@ -49,42 +54,39 @@ def embed_documents(
memory efficiently, especially for large datasets.
Args:
-            texts (List[str]): A list of text documents to be embedded.
-            embedding_model_name (str): The name of the Sentence Transformer model to use.
-            embedding_dtype (str): The data type to use for the embeddings ('float32', 'float16', or 'bfloat16').
-            embedding_device (str, optional): The device to use for embedding ('cuda' or 'cpu').
-                Defaults to 'cuda' if available, else 'cpu'.
-            batch_size (int, optional): The number of texts to process in each batch. Defaults to 32.
-            max_seq_length (int, optional): The maximum sequence length for the model. Defaults to 512.
+            texts (Union[List[str], pd.Series]): A list of text documents to be embedded.
+            verbose (bool, optional): Kept for compatibility with the BaseEmbedder interface. Defaults to False.
Returns:
-            Tuple[SentenceTransformer, np.ndarray]: A tuple containing:
+            Tuple[str | SentenceTransformer, np.ndarray, List[List[str]] | None, List[np.ndarray] | None]: A tuple containing:
- The loaded and configured Sentence Transformer model.
- A numpy array of embeddings, where each row corresponds to a text in the input list.
+                - A list of grouped token strings per document (local service only, else None).
+                - A list of per-document arrays of grouped token embeddings (local service only, else None).
Raises:
ValueError: If an invalid embedding_dtype is provided.
"""
+ # Convert to list if input is a pandas Series
+ if isinstance(texts, pd.Series):
+ texts = texts.tolist()
if self.local:
- return self._local_embed_documents(
- texts,
- embedding_model_name,
- embedding_dtype,
- )
+ return self._local_embed_documents(texts)
else:
- return self._remote_embed_documents(
- texts,
- )
+ return self._remote_embed_documents(texts)
def _local_embed_documents(
self,
texts: List[str],
- embedding_model_name: str,
- embedding_dtype: str,
embedding_device: str = EMBEDDING_DEVICE,
batch_size: int = EMBEDDING_BATCH_SIZE,
max_seq_length: int = EMBEDDING_MAX_SEQ_LENGTH,
- ) -> Tuple[SentenceTransformer, np.ndarray]:
+ ) -> Tuple[SentenceTransformer, np.ndarray, List[List[str]], List[np.ndarray]]:
"""
Embed a list of documents using a Sentence Transformer model.
@@ -111,23 +113,27 @@ def _local_embed_documents(
"""
# Configure model kwargs based on the specified dtype
model_kwargs = {}
- if embedding_dtype == "float16":
+ if self.embedding_dtype == "float16":
model_kwargs["torch_dtype"] = torch.float16
- elif embedding_dtype == "bfloat16":
+ elif self.embedding_dtype == "bfloat16":
model_kwargs["torch_dtype"] = torch.bfloat16
- elif embedding_dtype != "float32":
+ elif self.embedding_dtype != "float32":
raise ValueError(
"Invalid embedding_dtype. Must be 'float32', 'float16', or 'bfloat16'."
)
- # Load the embedding model
- embedding_model = SentenceTransformer(
- embedding_model_name,
- device=embedding_device,
- trust_remote_code=True,
- model_kwargs=model_kwargs,
- )
- embedding_model.max_seq_length = max_seq_length
+ if self.embedding_model is None:
+ logger.info(f"Loading embedding model: {self.embedding_model_name}...")
+ # Load the embedding model
+ self.embedding_model = SentenceTransformer(
+ self.embedding_model_name,
+ device=embedding_device,
+ trust_remote_code=True,
+ model_kwargs=model_kwargs,
+ )
+ self.embedding_model.max_seq_length = max_seq_length
+ self.batch_size = batch_size
+ logger.debug("Embedding model loaded")
# Calculate the number of batches
num_batches = (len(texts) + batch_size - 1) // batch_size
@@ -140,19 +146,42 @@ def _local_embed_documents(
start_idx = i * batch_size
end_idx = min(start_idx + batch_size, len(texts))
batch_texts = texts[start_idx:end_idx]
- batch_embeddings = embedding_model.encode(
- batch_texts, show_progress_bar=False
+ batch_embeddings = self.embedding_model.encode(
+ batch_texts,
+ show_progress_bar=False,
+ output_value=None, # to get all output values, not only sentence embeddings
)
embeddings.append(batch_embeddings)
# Concatenate all batch embeddings
- embeddings = np.concatenate(embeddings, axis=0)
+ all_embeddings = np.concatenate(embeddings, axis=0)
+ logger.success(f"Embedded {len(texts)} documents in {num_batches} batches")
+
+ embeddings = [
+ item["sentence_embedding"].detach().cpu() for item in all_embeddings
+ ]
+ embeddings = torch.stack(embeddings)
+ embeddings = embeddings.numpy()
+
+ token_embeddings = [
+ item["token_embeddings"].detach().cpu() for item in all_embeddings
+ ]
+ token_ids = [item["input_ids"].detach().cpu() for item in all_embeddings]
+
+ token_embeddings = convert_to_numpy(token_embeddings)
+ token_ids = convert_to_numpy(token_ids, type="token_id")
+
+ tokenizer = self.embedding_model._first_module().tokenizer
+
+ token_strings, token_embeddings = group_tokens(
+ tokenizer, token_ids, token_embeddings, language="french"
+ )
- return embedding_model, embeddings
+ return self.embedding_model, embeddings, token_strings, token_embeddings
def _remote_embed_documents(
self, texts: List[str], show_progress_bar: bool = True
- ) -> Tuple[str, np.ndarray]:
+ ) -> Tuple[str, np.ndarray, None, None]:
"""
Embed a list of documents using a Sentence Transformer model.
@@ -185,7 +214,7 @@ def _remote_embed_documents(
if response.status_code == 200:
embeddings = np.array(response.json()["embeddings"])
logger.debug(f"Computing embeddings done for batch")
- return self._get_remote_model_name(), embeddings
+ return self._get_remote_model_name(), embeddings, None, None
else:
logger.error(f"Error: {response.status_code}")
raise Exception(f"Error: {response.status_code}")
@@ -204,3 +233,86 @@ def _get_remote_model_name(self) -> str:
else:
logger.error(f"Error: {response.status_code}")
raise Exception(f"Error: {response.status_code}")
+
+
+def convert_to_numpy(obj, type=None):
+ """
+ Convert a torch.Tensor or list of torch.Tensors to numpy arrays.
+ Args:
+ obj: The object to convert (torch.Tensor or list).
+ type: The type of conversion (optional, used for token ids).
+ Returns:
+ np.ndarray or list of np.ndarray.
+ """
+ if isinstance(obj, torch.Tensor):
+ return (
+ obj.numpy().astype(np.int64)
+ if type == "token_id"
+ else obj.numpy().astype(np.float32)
+ )
+ elif isinstance(obj, list):
+        # Propagate `type` so nested token-id tensors keep their integer dtype
+        return [convert_to_numpy(item, type=type) for item in obj]
+ else:
+ raise TypeError("Object must be a list or torch.Tensor")
+
+
+def group_tokens(tokenizer, token_ids, token_embeddings, language="french"):
+ """
+ Group split tokens into whole words and average their embeddings.
+ Args:
+ tokenizer: The tokenizer to use for converting ids to tokens.
+ token_ids: List of token ids.
+ token_embeddings: List of token embeddings.
+ language: The language of the tokens (default is "french").
+ Returns:
+ List of grouped tokens and their corresponding embeddings.
+ """
+ grouped_token_lists = []
+ grouped_embedding_lists = []
+
+ special_tokens = {
+ "english": ["[CLS]", "[SEP]", "[PAD]"],
+        "french": ["<s>", "</s>", "<pad>"],
+ }
+ subword_prefix = {"english": "##", "french": "▁"}
+
+ for token_id, token_embedding in tqdm(
+ zip(token_ids, token_embeddings), desc="Grouping split tokens into whole words"
+ ):
+ tokens = tokenizer.convert_ids_to_tokens(token_id)
+
+ grouped_tokens = []
+ grouped_embeddings = []
+ current_word = ""
+ current_embedding = []
+
+ for token, embedding in zip(tokens, token_embedding):
+ if token in special_tokens[language]:
+ continue
+
+ if language == "french" and token.startswith(subword_prefix[language]):
+ if current_word:
+ grouped_tokens.append(current_word)
+ grouped_embeddings.append(np.mean(current_embedding, axis=0))
+ current_word = token[1:]
+ current_embedding = [embedding]
+ elif language == "english" and not token.startswith(
+ subword_prefix[language]
+ ):
+ if current_word:
+ grouped_tokens.append(current_word)
+ grouped_embeddings.append(np.mean(current_embedding, axis=0))
+ current_word = token
+ current_embedding = [embedding]
+ else:
+ current_word += token.lstrip(subword_prefix[language])
+ current_embedding.append(embedding)
+
+ if current_word:
+ grouped_tokens.append(current_word)
+ grouped_embeddings.append(np.mean(current_embedding, axis=0))
+
+ grouped_token_lists.append(grouped_tokens)
+ grouped_embedding_lists.append(np.array(grouped_embeddings))
+
+ return grouped_token_lists, grouped_embedding_lists
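
Reviewer note: the subword-grouping behavior of the new `group_tokens` helper can be checked without loading a model; the toy WordPiece-style tokenizer below is a stand-in assumption for the real sentence-transformers tokenizer.

```python
# Toy check of group_tokens: WordPiece pieces ("##"-prefixed continuations)
# are merged into whole words and their embeddings averaged.
import numpy as np

from bertrend.services.embedding_service import group_tokens


class ToyTokenizer:
    """Stand-in tokenizer mapping ids straight to a fixed token list."""

    vocab = {0: "[CLS]", 1: "embed", 2: "##ding", 3: "models", 4: "[SEP]"}

    def convert_ids_to_tokens(self, ids):
        return [self.vocab[int(i)] for i in ids]


token_ids = [np.array([0, 1, 2, 3, 4])]                 # one document
token_embeddings = [np.ones((5, 4), dtype=np.float32)]  # 5 tokens x 4 dims

tokens, embs = group_tokens(
    ToyTokenizer(), token_ids, token_embeddings, language="english"
)
print(tokens)         # [['embedding', 'models']]
print(embs[0].shape)  # (2, 4): one averaged vector per whole word
```
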
diff --git a/bertrend/train.py b/bertrend/train.py
index 083cca0..6d8c9c5 100644
--- a/bertrend/train.py
+++ b/bertrend/train.py
@@ -25,6 +25,7 @@
from bertrend import BASE_CACHE_PATH, LLM_CONFIG
from bertrend.parameters import STOPWORDS
from bertrend.llm_utils.openai_client import OpenAI_Client
+from bertrend.services.embedding_service import convert_to_numpy, group_tokens
from bertrend.utils.data_loading import TEXT_COLUMN
from bertrend.llm_utils.prompts import BERTOPIC_FRENCH_TOPIC_REPRESENTATION_PROMPT
@@ -109,89 +110,6 @@ def embed(self, documents: Union[List[str], pd.Series], verbose=True) -> np.ndar
return all_embeddings
-def convert_to_numpy(obj, type=None):
- """
- Convert a torch.Tensor or list of torch.Tensors to numpy arrays.
- Args:
- obj: The object to convert (torch.Tensor or list).
- type: The type of conversion (optional, used for token ids).
- Returns:
- np.ndarray or list of np.ndarray.
- """
- if isinstance(obj, torch.Tensor):
- return (
- obj.numpy().astype(np.int64)
- if type == "token_id"
- else obj.numpy().astype(np.float32)
- )
- elif isinstance(obj, list):
- return [convert_to_numpy(item) for item in obj]
- else:
- raise TypeError("Object must be a list or torch.Tensor")
-
-
-def group_tokens(tokenizer, token_ids, token_embeddings, language="french"):
- """
- Group split tokens into whole words and average their embeddings.
- Args:
- tokenizer: The tokenizer to use for converting ids to tokens.
- token_ids: List of token ids.
- token_embeddings: List of token embeddings.
- language: The language of the tokens (default is "french").
- Returns:
- List of grouped tokens and their corresponding embeddings.
- """
- grouped_token_lists = []
- grouped_embedding_lists = []
-
- special_tokens = {
- "english": ["[CLS]", "[SEP]", "[PAD]"],
-        "french": ["<s>", "</s>", "<pad>"],
- }
- subword_prefix = {"english": "##", "french": "▁"}
-
- for token_id, token_embedding in tqdm(
- zip(token_ids, token_embeddings), desc="Grouping split tokens into whole words"
- ):
- tokens = tokenizer.convert_ids_to_tokens(token_id)
-
- grouped_tokens = []
- grouped_embeddings = []
- current_word = ""
- current_embedding = []
-
- for token, embedding in zip(tokens, token_embedding):
- if token in special_tokens[language]:
- continue
-
- if language == "french" and token.startswith(subword_prefix[language]):
- if current_word:
- grouped_tokens.append(current_word)
- grouped_embeddings.append(np.mean(current_embedding, axis=0))
- current_word = token[1:]
- current_embedding = [embedding]
- elif language == "english" and not token.startswith(
- subword_prefix[language]
- ):
- if current_word:
- grouped_tokens.append(current_word)
- grouped_embeddings.append(np.mean(current_embedding, axis=0))
- current_word = token
- current_embedding = [embedding]
- else:
- current_word += token.lstrip(subword_prefix[language])
- current_embedding.append(embedding)
-
- if current_word:
- grouped_tokens.append(current_word)
- grouped_embeddings.append(np.mean(current_embedding, axis=0))
-
- grouped_token_lists.append(grouped_tokens)
- grouped_embedding_lists.append(np.array(grouped_embeddings))
-
- return grouped_token_lists, grouped_embedding_lists
-
-
def remove_special_tokens(tokenizer, token_id, token_embedding, special_tokens):
"""
Remove special tokens from the token ids and embeddings.
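
Reviewer note: with the `type` propagation fix above, the relocated `convert_to_numpy` keeps token ids integral through the recursive list branch; a quick standalone check:

```python
# Standalone dtype check for convert_to_numpy (assumes the type-propagation
# fix in the recursive list branch noted above).
import torch

from bertrend.services.embedding_service import convert_to_numpy

ids = [torch.tensor([0, 1, 2]), torch.tensor([3, 4])]
vecs = [torch.rand(3, 8), torch.rand(2, 8)]

print(convert_to_numpy(ids, type="token_id")[0].dtype)  # int64
print(convert_to_numpy(vecs)[0].dtype)                  # float32
```
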
diff --git a/bertrend_apps/newsletters/__main__.py b/bertrend_apps/newsletters/__main__.py
index 3fc7b00..e4fc59c 100644
--- a/bertrend_apps/newsletters/__main__.py
+++ b/bertrend_apps/newsletters/__main__.py
@@ -24,6 +24,7 @@
from bertrend import FEED_BASE_PATH, BEST_CUDA_DEVICE, OUTPUT_PATH
from bertrend.parameters import BERTOPIC_SERIALIZATION
+from bertrend.services.embedding_service import EmbeddingService
from bertrend.utils.config_utils import load_toml_config
from bertrend.utils.data_loading import (
enhanced_split_df_by_paragraphs,
@@ -35,7 +36,7 @@
generate_newsletter,
export_md_string,
)
-from bertrend.train import EmbeddingModel, train_BERTopic
+from bertrend.train import train_BERTopic
from bertrend_apps.common.mail_utils import get_credentials, send_email
from bertrend_apps.common.crontab_utils import schedule_newsletter
@@ -122,7 +123,7 @@ def newsletter_from_feed(
topic_model = _load_topic_model(model_path)
logger.info(f"Topic model loaded from {model_path}")
logger.info("Computation of embeddings for new data...")
-        embeddings = EmbeddingModel(
-            config.get("topic_model.embedding", "model_name")
-        ).embed(dataset[TEXT_COLUMN])
+        _, embeddings, _, _ = EmbeddingService(
+            model_name=config.get("topic_model.embedding", "model_name")
+        ).embed(dataset[TEXT_COLUMN])
topics, probs = topic_model.transform(dataset[TEXT_COLUMN], embeddings)