diff --git a/bertrend/demos/demos_utils/embed_documents_component.py b/bertrend/demos/demos_utils/embed_documents_component.py
index 0d7cf1b..b92e336 100644
--- a/bertrend/demos/demos_utils/embed_documents_component.py
+++ b/bertrend/demos/demos_utils/embed_documents_component.py
@@ -19,7 +19,9 @@ def display_embed_documents_component() -> bool:
     embedding_model_name = SessionStateManager.get("embedding_model_name")
     if SessionStateManager.get("embedding_service_type", "local") == "local":
         embedding_service = EmbeddingService(
-            local=True, embedding_model_name=embedding_model_name
+            local=True,
+            model_name=embedding_model_name,
+            embedding_dtype=embedding_dtype,
         )
     else:
         embedding_service = EmbeddingService(
@@ -32,14 +34,17 @@ def display_embed_documents_component() -> bool:
            TEXT_COLUMN
        ].tolist()

-       embedding_model, embeddings = embedding_service.embed_documents(
-           texts=texts,
-           embedding_model_name=embedding_model_name,
-           embedding_dtype=embedding_dtype,
+       embedding_model, embeddings, token_strings, token_embeddings = (
+           embedding_service.embed(
+               texts=texts,
+           )
        )
        SessionStateManager.set("embedding_model", embedding_model)
        SessionStateManager.set("embeddings", embeddings)
+       SessionStateManager.set("token_strings", token_strings)
+       SessionStateManager.set("token_embeddings", token_embeddings)
+       SessionStateManager.set("data_embedded", True)
        st.success(EMBEDDINGS_CALCULATED_MESSAGE, icon=SUCCESS_ICON)
diff --git a/bertrend/demos/weak_signals/app.py b/bertrend/demos/weak_signals/app.py
index a0968a0..0b6d52e 100644
--- a/bertrend/demos/weak_signals/app.py
+++ b/bertrend/demos/weak_signals/app.py
@@ -41,7 +41,6 @@
     display_bertopic_hyperparameters,
     display_bertrend_hyperparameters,
 )
-from bertrend.services.embedding_service import EmbeddingService
 from bertrend.topic_model import TopicModel
 from bertrend.demos.weak_signals.messages import (
     MODEL_MERGING_COMPLETE_MESSAGE,
diff --git a/bertrend/services/embedding_service.py b/bertrend/services/embedding_service.py
index 612037e..343f1a2 100644
--- a/bertrend/services/embedding_service.py
+++ b/bertrend/services/embedding_service.py
@@ -3,11 +3,13 @@
 # SPDX-License-Identifier: MPL-2.0
 # This file is part of BERTrend.
 import json
-from typing import List, Tuple
+from typing import List, Tuple, Union

 import numpy as np
+import pandas as pd
 import requests
 import torch
+from bertopic.backend import BaseEmbedder
 from loguru import logger
 from sentence_transformers import SentenceTransformer
 from tqdm import tqdm
@@ -20,27 +22,30 @@
 )


-class EmbeddingService:
+class EmbeddingService(BaseEmbedder):
     def __init__(
         self,
         local: bool = True,
-        embedding_model_name: str = None,
+        model_name: str = None,
+        embedding_dtype: str = None,
         host: str = EMBEDDING_CONFIG["host"],
         port: str = EMBEDDING_CONFIG["port"],
     ):
+        super().__init__()
         self.local = local
         if not self.local:
             self.url = f"http://{host}:{port}"
-        self.embedding_model_name = embedding_model_name
-        # TODO: harmonize interfaces for local / remote services
+        self.embedding_model = None
+        self.embedding_model_name = model_name
+        self.embedding_dtype = embedding_dtype

-    def embed_documents(
-        self,
-        texts: List[str],
-        embedding_model_name: str,
-        embedding_dtype: str,
-    ) -> Tuple[str | SentenceTransformer, np.ndarray]:
+    def embed(self, texts: Union[List[str], pd.Series], verbose: bool = False) -> Tuple[
+        str | SentenceTransformer,
+        np.ndarray,
+        List[List[str]] | None,
+        List[np.ndarray] | None,
+    ]:
         """
         Embed a list of documents using a Sentence Transformer model.
@@ -49,42 +54,39 @@ def embed_documents(
         memory efficiently, especially for large datasets.

         Args:
-            texts (List[str]): A list of text documents to be embedded.
-            embedding_model_name (str): The name of the Sentence Transformer model to use.
-            embedding_dtype (str): The data type to use for the embeddings ('float32', 'float16', or 'bfloat16').
-            embedding_device (str, optional): The device to use for embedding ('cuda' or 'cpu').
-                Defaults to 'cuda' if available, else 'cpu'.
-            batch_size (int, optional): The number of texts to process in each batch. Defaults to 32.
-            max_seq_length (int, optional): The maximum sequence length for the model. Defaults to 512.
+            texts (Union[List[str], pd.Series]): A list of text documents to be embedded.
+            verbose (bool, optional): Kept for interface compatibility with BaseEmbedder.embed.
+                Defaults to False.

         Returns:
-            Tuple[SentenceTransformer, np.ndarray]: A tuple containing:
+            Tuple[str | SentenceTransformer, np.ndarray, List[List[str]] | None, List[np.ndarray] | None]:
+            A tuple containing:
                 - The loaded and configured Sentence Transformer model.
                 - A numpy array of embeddings, where each row corresponds to a text in the input list.
+                - A list of word-level token strings per document (local service only, otherwise None).
+                - A list of word-level token embedding arrays per document (local service only, otherwise None).

         Raises:
             ValueError: If an invalid embedding_dtype is provided.
         """
+        # Convert to list if input is a pandas Series
+        if isinstance(texts, pd.Series):
+            texts = texts.tolist()
         if self.local:
-            return self._local_embed_documents(
-                texts,
-                embedding_model_name,
-                embedding_dtype,
-            )
+            return self._local_embed_documents(texts)
         else:
-            return self._remote_embed_documents(
-                texts,
-            )
+            return self._remote_embed_documents(texts)

     def _local_embed_documents(
         self,
         texts: List[str],
-        embedding_model_name: str,
-        embedding_dtype: str,
         embedding_device: str = EMBEDDING_DEVICE,
         batch_size: int = EMBEDDING_BATCH_SIZE,
         max_seq_length: int = EMBEDDING_MAX_SEQ_LENGTH,
-    ) -> Tuple[SentenceTransformer, np.ndarray]:
+    ) -> Tuple[SentenceTransformer, np.ndarray, List[List[str]], List[np.ndarray]]:
         """
         Embed a list of documents using a Sentence Transformer model.
@@ -111,23 +113,27 @@
         """
         # Configure model kwargs based on the specified dtype
         model_kwargs = {}
-        if embedding_dtype == "float16":
+        if self.embedding_dtype == "float16":
             model_kwargs["torch_dtype"] = torch.float16
-        elif embedding_dtype == "bfloat16":
+        elif self.embedding_dtype == "bfloat16":
             model_kwargs["torch_dtype"] = torch.bfloat16
-        elif embedding_dtype != "float32":
+        # None (the constructor default) falls back to the model's default dtype
+        elif self.embedding_dtype is not None and self.embedding_dtype != "float32":
             raise ValueError(
                 "Invalid embedding_dtype. Must be 'float32', 'float16', or 'bfloat16'."
             )

-        # Load the embedding model
-        embedding_model = SentenceTransformer(
-            embedding_model_name,
-            device=embedding_device,
-            trust_remote_code=True,
-            model_kwargs=model_kwargs,
-        )
-        embedding_model.max_seq_length = max_seq_length
+        if self.embedding_model is None:
+            logger.info(f"Loading embedding model: {self.embedding_model_name}...")
+            # Load the embedding model
+            self.embedding_model = SentenceTransformer(
+                self.embedding_model_name,
+                device=embedding_device,
+                trust_remote_code=True,
+                model_kwargs=model_kwargs,
+            )
+            self.embedding_model.max_seq_length = max_seq_length
+            self.batch_size = batch_size
+            logger.debug("Embedding model loaded")

         # Calculate the number of batches
         num_batches = (len(texts) + batch_size - 1) // batch_size
@@ -140,19 +146,42 @@
             start_idx = i * batch_size
             end_idx = min(start_idx + batch_size, len(texts))
             batch_texts = texts[start_idx:end_idx]
-            batch_embeddings = embedding_model.encode(
-                batch_texts, show_progress_bar=False
+            batch_embeddings = self.embedding_model.encode(
+                batch_texts,
+                show_progress_bar=False,
+                output_value=None,  # to get all output values, not only sentence embeddings
             )
             embeddings.append(batch_embeddings)

         # Concatenate all batch embeddings
-        embeddings = np.concatenate(embeddings, axis=0)
+        all_embeddings = np.concatenate(embeddings, axis=0)
+        logger.success(f"Embedded {len(texts)} documents in {num_batches} batches")
+
+        embeddings = [
+            item["sentence_embedding"].detach().cpu() for item in all_embeddings
+        ]
+        embeddings = torch.stack(embeddings)
+        embeddings = embeddings.numpy()
+
+        token_embeddings = [
+            item["token_embeddings"].detach().cpu() for item in all_embeddings
+        ]
+        token_ids = [item["input_ids"].detach().cpu() for item in all_embeddings]
+
+        token_embeddings = convert_to_numpy(token_embeddings)
+        token_ids = convert_to_numpy(token_ids, type="token_id")
+
+        tokenizer = self.embedding_model._first_module().tokenizer
+
+        token_strings, token_embeddings = group_tokens(
+            tokenizer, token_ids, token_embeddings, language="french"
+        )

-        return embedding_model, embeddings
+        return self.embedding_model, embeddings, token_strings, token_embeddings

     def _remote_embed_documents(
         self, texts: List[str], show_progress_bar: bool = True
-    ) -> Tuple[str, np.ndarray]:
+    ) -> Tuple[str, np.ndarray, None, None]:
         """
         Embed a list of documents using a Sentence Transformer model.

@@ -185,7 +214,7 @@
         if response.status_code == 200:
             embeddings = np.array(response.json()["embeddings"])
             logger.debug(f"Computing embeddings done for batch")
-            return self._get_remote_model_name(), embeddings
+            return self._get_remote_model_name(), embeddings, None, None
         else:
             logger.error(f"Error: {response.status_code}")
             raise Exception(f"Error: {response.status_code}")
@@ -204,3 +233,86 @@
         else:
             logger.error(f"Error: {response.status_code}")
             raise Exception(f"Error: {response.status_code}")
+
+
+def convert_to_numpy(obj, type=None):
+    """
+    Convert a torch.Tensor or list of torch.Tensors to numpy arrays.
+    Args:
+        obj: The object to convert (torch.Tensor or list).
+        type: The type of conversion (optional, used for token ids).
+    Returns:
+        np.ndarray or list of np.ndarray.
+ """ + if isinstance(obj, torch.Tensor): + return ( + obj.numpy().astype(np.int64) + if type == "token_id" + else obj.numpy().astype(np.float32) + ) + elif isinstance(obj, list): + return [convert_to_numpy(item) for item in obj] + else: + raise TypeError("Object must be a list or torch.Tensor") + + +def group_tokens(tokenizer, token_ids, token_embeddings, language="french"): + """ + Group split tokens into whole words and average their embeddings. + Args: + tokenizer: The tokenizer to use for converting ids to tokens. + token_ids: List of token ids. + token_embeddings: List of token embeddings. + language: The language of the tokens (default is "french"). + Returns: + List of grouped tokens and their corresponding embeddings. + """ + grouped_token_lists = [] + grouped_embedding_lists = [] + + special_tokens = { + "english": ["[CLS]", "[SEP]", "[PAD]"], + "french": ["", "", ""], + } + subword_prefix = {"english": "##", "french": "▁"} + + for token_id, token_embedding in tqdm( + zip(token_ids, token_embeddings), desc="Grouping split tokens into whole words" + ): + tokens = tokenizer.convert_ids_to_tokens(token_id) + + grouped_tokens = [] + grouped_embeddings = [] + current_word = "" + current_embedding = [] + + for token, embedding in zip(tokens, token_embedding): + if token in special_tokens[language]: + continue + + if language == "french" and token.startswith(subword_prefix[language]): + if current_word: + grouped_tokens.append(current_word) + grouped_embeddings.append(np.mean(current_embedding, axis=0)) + current_word = token[1:] + current_embedding = [embedding] + elif language == "english" and not token.startswith( + subword_prefix[language] + ): + if current_word: + grouped_tokens.append(current_word) + grouped_embeddings.append(np.mean(current_embedding, axis=0)) + current_word = token + current_embedding = [embedding] + else: + current_word += token.lstrip(subword_prefix[language]) + current_embedding.append(embedding) + + if current_word: + grouped_tokens.append(current_word) + grouped_embeddings.append(np.mean(current_embedding, axis=0)) + + grouped_token_lists.append(grouped_tokens) + grouped_embedding_lists.append(np.array(grouped_embeddings)) + + return grouped_token_lists, grouped_embedding_lists diff --git a/bertrend/train.py b/bertrend/train.py index 083cca0..6d8c9c5 100644 --- a/bertrend/train.py +++ b/bertrend/train.py @@ -25,6 +25,7 @@ from bertrend import BASE_CACHE_PATH, LLM_CONFIG from bertrend.parameters import STOPWORDS from bertrend.llm_utils.openai_client import OpenAI_Client +from bertrend.services.embedding_service import convert_to_numpy, group_tokens from bertrend.utils.data_loading import TEXT_COLUMN from bertrend.llm_utils.prompts import BERTOPIC_FRENCH_TOPIC_REPRESENTATION_PROMPT @@ -109,89 +110,6 @@ def embed(self, documents: Union[List[str], pd.Series], verbose=True) -> np.ndar return all_embeddings -def convert_to_numpy(obj, type=None): - """ - Convert a torch.Tensor or list of torch.Tensors to numpy arrays. - Args: - obj: The object to convert (torch.Tensor or list). - type: The type of conversion (optional, used for token ids). - Returns: - np.ndarray or list of np.ndarray. 
- """ - if isinstance(obj, torch.Tensor): - return ( - obj.numpy().astype(np.int64) - if type == "token_id" - else obj.numpy().astype(np.float32) - ) - elif isinstance(obj, list): - return [convert_to_numpy(item) for item in obj] - else: - raise TypeError("Object must be a list or torch.Tensor") - - -def group_tokens(tokenizer, token_ids, token_embeddings, language="french"): - """ - Group split tokens into whole words and average their embeddings. - Args: - tokenizer: The tokenizer to use for converting ids to tokens. - token_ids: List of token ids. - token_embeddings: List of token embeddings. - language: The language of the tokens (default is "french"). - Returns: - List of grouped tokens and their corresponding embeddings. - """ - grouped_token_lists = [] - grouped_embedding_lists = [] - - special_tokens = { - "english": ["[CLS]", "[SEP]", "[PAD]"], - "french": ["", "", ""], - } - subword_prefix = {"english": "##", "french": "▁"} - - for token_id, token_embedding in tqdm( - zip(token_ids, token_embeddings), desc="Grouping split tokens into whole words" - ): - tokens = tokenizer.convert_ids_to_tokens(token_id) - - grouped_tokens = [] - grouped_embeddings = [] - current_word = "" - current_embedding = [] - - for token, embedding in zip(tokens, token_embedding): - if token in special_tokens[language]: - continue - - if language == "french" and token.startswith(subword_prefix[language]): - if current_word: - grouped_tokens.append(current_word) - grouped_embeddings.append(np.mean(current_embedding, axis=0)) - current_word = token[1:] - current_embedding = [embedding] - elif language == "english" and not token.startswith( - subword_prefix[language] - ): - if current_word: - grouped_tokens.append(current_word) - grouped_embeddings.append(np.mean(current_embedding, axis=0)) - current_word = token - current_embedding = [embedding] - else: - current_word += token.lstrip(subword_prefix[language]) - current_embedding.append(embedding) - - if current_word: - grouped_tokens.append(current_word) - grouped_embeddings.append(np.mean(current_embedding, axis=0)) - - grouped_token_lists.append(grouped_tokens) - grouped_embedding_lists.append(np.array(grouped_embeddings)) - - return grouped_token_lists, grouped_embedding_lists - - def remove_special_tokens(tokenizer, token_id, token_embedding, special_tokens): """ Remove special tokens from the token ids and embeddings. 
diff --git a/bertrend_apps/newsletters/__main__.py b/bertrend_apps/newsletters/__main__.py
index 3fc7b00..e4fc59c 100644
--- a/bertrend_apps/newsletters/__main__.py
+++ b/bertrend_apps/newsletters/__main__.py
@@ -24,6 +24,7 @@
 from bertrend import FEED_BASE_PATH, BEST_CUDA_DEVICE, OUTPUT_PATH
 from bertrend.parameters import BERTOPIC_SERIALIZATION
+from bertrend.services.embedding_service import EmbeddingService
 from bertrend.utils.config_utils import load_toml_config
 from bertrend.utils.data_loading import (
     enhanced_split_df_by_paragraphs,
@@ -35,7 +36,7 @@
     generate_newsletter,
     export_md_string,
 )
-from bertrend.train import EmbeddingModel, train_BERTopic
+from bertrend.train import train_BERTopic
 from bertrend_apps.common.mail_utils import get_credentials, send_email
 from bertrend_apps.common.crontab_utils import schedule_newsletter

@@ -122,7 +123,7 @@ def newsletter_from_feed(
     topic_model = _load_topic_model(model_path)
     logger.info(f"Topic model loaded from {model_path}")
     logger.info("Computation of embeddings for new data...")
-    embeddings = EmbeddingModel(
-        config.get("topic_model.embedding", "model_name")
-    ).embed(dataset[TEXT_COLUMN])
+    _, embeddings, _, _ = EmbeddingService(
+        model_name=config.get("topic_model.embedding", "model_name")
+    ).embed(dataset[TEXT_COLUMN])
     topics, probs = topic_model.transform(dataset[TEXT_COLUMN], embeddings)
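For call sites, the net effect of the patch is that embedding becomes a stateful service: construct once with model_name and embedding_dtype, then call embed(), which now returns a 4-tuple instead of the old (model, embeddings) pair. The newsletters hunk above is adjusted accordingly (model_name passed by keyword, since the first positional parameter of EmbeddingService is local, and the tuple unpacked). A minimal usage sketch, not part of the patch, with an assumed model name and sample texts:

# Reviewer sketch, not part of the patch: the new EmbeddingService call pattern.
from bertrend.services.embedding_service import EmbeddingService

service = EmbeddingService(
    local=True,
    model_name="dangvantuan/sentence-camembert-large",  # assumption, not pinned by the patch
    embedding_dtype="float32",
)

texts = ["Première dépêche du flux.", "Seconde dépêche du flux."]

# embed() now returns a 4-tuple instead of the old (model, embeddings) pair.
model, embeddings, token_strings, token_embeddings = service.embed(texts)
assert embeddings.shape[0] == len(texts)  # one sentence vector per text
assert len(token_strings) == len(texts)   # word-level tokens (local mode only)

# Remote mode returns (model_name, embeddings, None, None), so callers that
# only need sentence vectors can unpack positionally, as in newsletters/__main__.py:
_, embeddings, _, _ = service.embed(texts)

One caveat worth flagging: bertopic.backend.BaseEmbedder.embed is conventionally expected to return just the embeddings array, so although EmbeddingService now subclasses BaseEmbedder, passing an instance directly to BERTopic as embedding_model would hand BERTopic this 4-tuple; callers must unpack it as above.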