From 0a073899735b57c974c235ebdb40af81117fe4d0 Mon Sep 17 00:00:00 2001
From: Erland366
Date: Thu, 31 Oct 2024 16:33:18 +0400
Subject: [PATCH] feat: enhance add_new_tokens function with resizing option and improved validation

---
 unsloth_zoo/tokenizer_utils.py | 38 +++++++++++++++++++++++++++++++---
 1 file changed, 35 insertions(+), 3 deletions(-)

diff --git a/unsloth_zoo/tokenizer_utils.py b/unsloth_zoo/tokenizer_utils.py
index aadd2e2..0964e29 100644
--- a/unsloth_zoo/tokenizer_utils.py
+++ b/unsloth_zoo/tokenizer_utils.py
@@ -73,15 +73,36 @@ def add_new_tokens(
     new_tokens = [],
     method = "mean",
     interpolation = 0.5,
+    resize_tokenizer: bool = True,
 ):
     """
     Smartly resizes the tokenizer and adds new tokens to the model.
     We also disregard untrained tokens by removing them from the mean calculation.
+
+    Parameters:
+    -----------
+    model : PreTrainedModel
+        The model to which new tokens will be added.
+    tokenizer : PreTrainedTokenizer
+        The tokenizer to be resized and updated with new tokens.
+    new_tokens : list or tuple, optional
+        A list or tuple of new tokens to be added. Default is an empty list.
+    method : str, optional
+        The method used to initialize the new token embeddings. Can be "mean" or "interpolation". Default is "mean".
+    interpolation : float, optional
+        The interpolation factor used when method is "interpolation". Must be between 0 and 1. Default is 0.5.
+    resize_tokenizer : bool, optional
+        A flag indicating whether the tokenizer should be resized. Set this to False if the tokenizer has
+        already been resized with respect to the model. Default is True.
     """
     assert(isinstance(new_tokens, (list, tuple)))
-    assert(len(new_tokens) > 0)
     assert(method == "mean" or method == "interpolation")
     assert(interpolation >= 0 and interpolation <= 1)
+    # Only require a non-empty new_tokens list when the tokenizer itself is being resized
+    if resize_tokenizer and len(new_tokens) == 0:
+        raise ValueError(
+            "Unsloth: You are trying to add new tokens to the model and tokenizer, but the new tokens list is empty."
+        )
 
     # Check if tokens already exist
     overlapping_tokens = set(new_tokens) & set(tokenizer.vocab.keys())
@@ -108,8 +129,19 @@
     old_config_size = model.config.vocab_size
 
     # Add tokens!
-    old_length = len(tokenizer)
-    tokenizer.add_tokens(new_tokens)
+    if resize_tokenizer:
+        old_length = len(tokenizer)
+        tokenizer.add_tokens(new_tokens)
+    else:
+        # If the tokenizer was already resized, the old length is the model's original vocab size
+        old_length = old_config_size
+
+        # Sort the vocabulary by token id
+        new_vocab_tokenizer = sorted(tokenizer.get_vocab().items(), key = lambda x: x[1])
+
+        # Recover the new tokens that sit beyond the old vocabulary
+        new_tokens = [x[0] for x in new_vocab_tokenizer[old_length:]]
+
 
     # Also resizes lm_head as well!
     model.resize_token_embeddings(len(tokenizer))
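
Usage sketch (not part of the patch): a minimal example of the new resize_tokenizer flag, assuming a Hugging Face model/tokenizer pair and the add_new_tokens(model, tokenizer, ...) signature implied by the docstring above; the checkpoint name and token strings are placeholders.

# Minimal usage sketch for the new resize_tokenizer flag (illustrative only;
# checkpoint name and token strings are placeholders).
from transformers import AutoModelForCausalLM, AutoTokenizer
from unsloth_zoo.tokenizer_utils import add_new_tokens

model     = AutoModelForCausalLM.from_pretrained("my-org/my-base-model")
tokenizer = AutoTokenizer.from_pretrained("my-org/my-base-model")

# Default behaviour: the function itself resizes the tokenizer and the embeddings.
add_new_tokens(model, tokenizer, new_tokens = ["<|tool_call|>", "<|tool_response|>"])

# New behaviour: the tokenizer was already resized elsewhere, so pass
# resize_tokenizer = False and let the function recover the added tokens
# from the vocabulary entries beyond model.config.vocab_size.
tokenizer.add_tokens(["<|scratchpad|>"])
add_new_tokens(model, tokenizer, resize_tokenizer = False)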