From 73e6ee16d207e889a36ef3bee27349edad188678 Mon Sep 17 00:00:00 2001
From: George Armstrong
Date: Fri, 25 Oct 2024 16:54:49 -0400
Subject: [PATCH] fix: correct batch tokenization when sequence exceeds
 encoder length (#352)

Signed-off-by: George Armstrong
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 nemo_aligner/utils/text_generation_utils.py |  2 +-
 tests/test_text_generation_utils.py         | 69 +++++++++++++++++++++
 2 files changed, 70 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_text_generation_utils.py

diff --git a/nemo_aligner/utils/text_generation_utils.py b/nemo_aligner/utils/text_generation_utils.py
index d4d60f7c4..817f5dc45 100644
--- a/nemo_aligner/utils/text_generation_utils.py
+++ b/nemo_aligner/utils/text_generation_utils.py
@@ -92,10 +92,10 @@ def tokenize(sentence):
         return output
 
     context_tokens = list(map(tokenize, sentences))
+    context_tokens = [x[:max_len] for x in context_tokens]
     max_sequence_length = max(len(x) for x in context_tokens)
 
     context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eos_id, max_len - max_sequence_length)
-    context_tokens = [x[:max_len] for x in context_tokens]
     context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor
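For context on the one-line reorder above: with the old ordering, prompts longer than max_len were only truncated after pad_batch had already computed lengths from the raw tokenization, so context_lengths could report values larger than max_len while the rows themselves were later cut to max_len. Moving the truncation first clamps both the rows and the recorded lengths. Below is a minimal sketch of that failure mode; this pad_batch is a simplified stand-in for the Megatron-style helper (pad every row to the longest row plus an extra allowance, recording pre-pad lengths), not the exact upstream implementation.

    def pad_batch(batch, pad_id, extra):
        """Pad each row to max(row lengths) + extra; record pre-pad lengths."""
        max_context_length = max(len(tokens) for tokens in batch)
        context_lengths = []
        for tokens in batch:
            context_lengths.append(len(tokens))
            # A negative multiplier yields an empty list, so over-long rows
            # are silently left unpadded (and untruncated).
            tokens.extend([pad_id] * (max_context_length + extra - len(tokens)))
        return batch, context_lengths

    max_len, eos_id = 10, 1
    rows = [list(range(20)), list(range(5))]  # first prompt exceeds max_len

    # Old ordering: lengths come from the untruncated rows.
    old, old_lengths = pad_batch([r[:] for r in rows], eos_id,
                                 max_len - max(len(r) for r in rows))
    print(old_lengths)  # [20, 5] -- 20 is out of range for a width-10 tensor

    # Fixed ordering: truncate first, then pad.
    truncated = [r[:max_len] for r in rows]
    new, new_lengths = pad_batch(truncated, eos_id,
                                 max_len - max(len(r) for r in truncated))
    print(new_lengths)  # [10, 5]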
diff --git a/tests/test_text_generation_utils.py b/tests/test_text_generation_utils.py
new file mode 100644
index 000000000..26b90a53b
--- /dev/null
+++ b/tests/test_text_generation_utils.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo_aligner.utils.text_generation_utils import tokenize_batch
+
+
+class MockTokenizer:
+    def __init__(self):
+        self.vocab = dict()
+        self.bos_id = 0
+        self.eos_id = 1
+        self.vocab["<s>"] = self.bos_id
+        self.vocab["</s>"] = self.eos_id
+
+    def text_to_ids(self, text):
+        tokens = list(text)
+        ids = [self.vocab.get(token, len(self.vocab)) for token in tokens]
+        return ids
+
+
+def test_tokenize_batch():
+    sentences = ["I went to the store.", "I bought a zoo."]
+    tokenizer = MockTokenizer()
+    max_len = 30
+    context_tokens_tensor, context_length_tensor = tokenize_batch(
+        sentences, tokenizer, max_len, add_BOS=False, add_EOS=False
+    )
+    assert context_tokens_tensor.shape == (
+        2,
+        30,
+    ), f"expected context_tokens_tensor shape to be (2, 30) but got {context_tokens_tensor.shape}"
+    assert context_length_tensor.shape == (
+        2,
+    ), f"expected context_length_tensor shape to be (2,) but got {context_length_tensor.shape}"
+    assert context_length_tensor.tolist() == [
+        20,
+        15,
+    ], f"expected context_length_tensor to be [20, 15] but got {context_length_tensor.tolist()}"
+
+
+def test_tokenize_batch_with_sentence_longer_than_max_len():
+    sentences = ["I went to the store.", "I bought a zoo."]
+    tokenizer = MockTokenizer()
+    max_len = 10
+    context_tokens_tensor, context_length_tensor = tokenize_batch(
+        sentences, tokenizer, max_len, add_BOS=False, add_EOS=False
+    )
+    assert context_tokens_tensor.shape == (
+        2,
+        10,
+    ), f"expected context_tokens_tensor shape to be (2, 10) but got {context_tokens_tensor.shape}"
+    assert context_length_tensor.shape == (
+        2,
+    ), f"expected context_length_tensor shape to be (2,) but got {context_length_tensor.shape}"
+    assert context_length_tensor.tolist() == [
+        10,
+        10,
+    ], f"expected context_length_tensor to be [10, 10] but got {context_length_tensor.tolist()}"
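A note on the expected values in these tests: MockTokenizer emits one id per character, so the asserted lengths are simply character counts clamped to max_len ("I went to the store." is 20 characters, "I bought a zoo." is 15). The sketch below reproduces that arithmetic without a GPU; the real tests need a CUDA device because tokenize_batch builds torch.cuda.LongTensor.

    sentences = ["I went to the store.", "I bought a zoo."]
    for max_len in (30, 10):
        # One id per character, truncated to max_len -- mirrors the fixed ordering.
        print(max_len, [min(len(s), max_len) for s in sentences])
    # 30 [20, 15]
    # 10 [10, 10]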