From 610ffae51d0f36d4af9f1265321334c0c53c0a87 Mon Sep 17 00:00:00 2001
From: sahusiddharth
Date: Sun, 20 Aug 2023 03:48:42 +0530
Subject: [PATCH] Fixed issue with WordPieceTokenizer by adding vocab_size
 argument

---
 keras_nlp/tokenizers/word_piece_tokenizer.py | 53 +++++++++++++++-----
 1 file changed, 41 insertions(+), 12 deletions(-)

diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py
index dc9ce49427..34bfd4041f 100644
--- a/keras_nlp/tokenizers/word_piece_tokenizer.py
+++ b/keras_nlp/tokenizers/word_piece_tokenizer.py
@@ -27,6 +27,7 @@
 from keras_nlp.utils.tensor_utils import convert_to_ragged_batch
 from keras_nlp.utils.tensor_utils import is_integer_dtype
 from keras_nlp.utils.tensor_utils import is_string_dtype
+from absl import logging
 
 try:
     import tensorflow_text as tf_text
@@ -202,6 +203,8 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
             plain text file containing a single WordPiece token per line.
         sequence_length: int. If set, the output will be converted to a dense
             tensor and padded/trimmed so all outputs are of sequence_length.
+        vocab_size: int. If set, force the vocabulary to be exactly
+            `vocab_size` by truncating the input vocabulary if necessary.
         lowercase: bool. If `True`, the input text will be
             lowercased before tokenization. Defaults to `False`.
         strip_accents: bool. If `True`, all accent marks will
@@ -294,6 +297,7 @@ def __init__(
         self,
         vocabulary=None,
         sequence_length: int = None,
+        vocab_size: int = None,
         lowercase: bool = False,
         strip_accents: bool = False,
         split: bool = True,
@@ -313,13 +317,46 @@ def __init__(
 
         super().__init__(dtype=dtype, **kwargs)
 
+        self.vocab_size = vocab_size
+        self.sequence_length = sequence_length
+        self.lowercase = lowercase
+        self.strip_accents = strip_accents
+        self.split = split
+        self.split_on_cjk = split_on_cjk
+        self.suffix_indicator = suffix_indicator
+        self.oov_token = oov_token
+
         if isinstance(vocabulary, str):
-            self.vocabulary = [
+            vocabulary_list = [
                 line.rstrip() for line in tf.io.gfile.GFile(vocabulary)
             ]
+            input_vocabulary_size = len(vocabulary_list)
+            if self.vocab_size is None:
+                self.vocab_size = input_vocabulary_size
+                self.vocabulary = vocabulary_list
+            elif self.vocab_size < input_vocabulary_size:
+                logging.warning(
+                    "Setting `vocab_size` to a smaller value than the input vocabulary file. "
+                    "Some token ids will never be output from the tokenizer."
+                )
+                self.vocabulary = vocabulary_list[:self.vocab_size]
+            else:
+                self.vocab_size = input_vocabulary_size
+                self.vocabulary = vocabulary_list
         elif isinstance(vocabulary, Iterable):
-            # Make a copy.
-            self.vocabulary = list(vocabulary)
+            input_vocabulary_size = len(vocabulary)
+            if self.vocab_size is None:
+                self.vocab_size = input_vocabulary_size
+                self.vocabulary = list(vocabulary)
+            elif self.vocab_size < input_vocabulary_size:
+                logging.warning(
+                    "Setting `vocab_size` to a smaller value than the input vocabulary. "
+                    "Some token ids will never be output from the tokenizer."
+                )
+                self.vocabulary = list(vocabulary)[:self.vocab_size]
+            else:
+                self.vocab_size = input_vocabulary_size
+                self.vocabulary = list(vocabulary)
         else:
             raise ValueError(
                 "Vocabulary must be an file path or list of terms. "
@@ -328,14 +365,6 @@ def __init__(
         if oov_token is None:
             raise ValueError("`oov_token` cannot be None.")
 
-        self.sequence_length = sequence_length
-        self.lowercase = lowercase
-        self.strip_accents = strip_accents
-        self.split = split
-        self.split_on_cjk = split_on_cjk
-        self.suffix_indicator = suffix_indicator
-        self.oov_token = oov_token
-
         if oov_token not in self.vocabulary:
             raise ValueError(
                 f'Cannot find `oov_token="{self.oov_token}"` in the '
@@ -360,7 +389,7 @@ def get_vocabulary(self) -> List[str]:
 
     def vocabulary_size(self) -> int:
         """Get the size of the tokenizer vocabulary."""
-        return len(self.vocabulary)
+        return self.vocab_size
 
     def id_to_token(self, id: int) -> str:
         """Convert an integer id to a string token."""