From d704750b0fcdd74e65414b685e1927368c9b0f5d Mon Sep 17 00:00:00 2001
From: Vidyaranya
Date: Sun, 23 Apr 2023 06:52:33 -0700
Subject: [PATCH] Generation should stop after two newlines if that is the
 stop criterion

Summary: This addresses Issue 642. When the stop token is \n\n, generation
should stop after generating two newlines.
---
 metaseq/hub_utils.py          | 20 ++++++++++++++++++--
 metaseq/sequence_generator.py | 20 +++++++++++++++++++-
 metaseq/tasks/base_task.py    | 11 +++++++++--
 3 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/metaseq/hub_utils.py b/metaseq/hub_utils.py
index a6690bb86..3397cd768 100644
--- a/metaseq/hub_utils.py
+++ b/metaseq/hub_utils.py
@@ -82,6 +82,15 @@ def decode(self, sentence: str) -> str:
         return self.tokenizer.decode(sentence)
 
 
+class RecurringPunctuation(object):
+    """Class for grouping tokens of similar type, for example \n and \n\n."""
+
+    def __init__(self, single_token, multiple_token):
+        super().__init__()
+        self.single_token = single_token
+        self.multiple_token = multiple_token
+
+
 class GeneratorInterface:
     """
     PyTorch Hub interface for generating sequences from a pre-trained
@@ -323,14 +332,21 @@ def generate(
             self.cfg.generation,
             extra_gen_cls_kwargs=extra_gen_cls_kwargs,
         )
-
         # okay actually generate
         logger.info(f"Executing generation on input tensor size {src_tokens.shape}")
         if use_cuda:
             batch = utils.move_to_cuda(batch)
 
         translate_start_time = time.time()
-        translations = self.task.inference_step(generator, self.models, batch)
+        recurring_punctuation = RecurringPunctuation(
+            self.bpe.bpe.encode("\n").ids[0], self.bpe.bpe.encode("\n\n").ids[0]
+        )
+        translations = self.task.inference_step(
+            generator,
+            self.models,
+            batch,
+            recurring_punctuation=recurring_punctuation,
+        )
         translate_time = time.time() - translate_start_time
         total_generation_time += translate_time
 
diff --git a/metaseq/sequence_generator.py b/metaseq/sequence_generator.py
index 047e35a52..60ef8378a 100644
--- a/metaseq/sequence_generator.py
+++ b/metaseq/sequence_generator.py
@@ -19,6 +19,7 @@
 from metaseq import utils
 from metaseq.data import data_utils
 from metaseq.models import BaseDecoder
+from metaseq.hub_utils import RecurringPunctuation
 
 logger = logging.getLogger(__name__)
 
@@ -130,6 +131,12 @@ def forward(
         sample: Dict[str, Dict[str, Tensor]],
         prefix_tokens: Optional[Tensor] = None,
         bos_token: Optional[int] = None,
+        recurring_punctuation: Optional[RecurringPunctuation] = None,
     ):
         """Generate a batch of translations."""
-        return self._generate(sample, prefix_tokens, bos_token=bos_token)
+        return self._generate(
+            sample,
+            prefix_tokens,
+            bos_token=bos_token,
+            recurring_punctuation=recurring_punctuation,
+        )
@@ -144,6 +151,7 @@ def _generate(
         sample: Dict[str, Dict[str, Tensor]],
         prefix_tokens: Optional[Tensor] = None,
         bos_token: Optional[int] = None,
+        recurring_punctuation: Optional[RecurringPunctuation] = None,
     ):
         """
         Args:
@@ -268,6 +276,7 @@ def _generate(
 
         eos_mask = torch.zeros(lprobs.size(0), dtype=torch.bool, device=lprobs.device)
 
+        prev_token = None
         for step in range(start_step, max_len):
             if step < min_len:
                 # minimum length constraint (does not apply if using prefix_tokens)
@@ -303,13 +312,22 @@ def _generate(
                 all_lprobs[:, step] = lprobs
 
             eos_mask |= next_toks == self.eos
+
             for stop_token in self.stop:
                 # if there are other early stopping tokens, allow those to trigger stop
                 eos_mask |= next_toks == stop_token
+                # treat two consecutive single tokens (e.g. "\n" then "\n")
+                # as the multi-token stop sequence (e.g. "\n\n")
+                if recurring_punctuation is not None and prev_token is not None:
+                    if recurring_punctuation.multiple_token == stop_token:
+                        eos_mask |= (
+                            next_toks == recurring_punctuation.single_token
+                        ) & (prev_token == recurring_punctuation.single_token)
 
             if torch.all(eos_mask):
                 break
+            prev_token = next_toks
 
             # forward through the next pass
             model_out = self.model.decoder(
                 tokens[:, : step + 1],
 
diff --git a/metaseq/tasks/base_task.py b/metaseq/tasks/base_task.py
index db9c1760b..44949deea 100644
--- a/metaseq/tasks/base_task.py
+++ b/metaseq/tasks/base_task.py
@@ -420,9 +420,16 @@ def build_dataset_for_inference(
     ) -> torch.utils.data.Dataset:
         raise NotImplementedError
 
-    def inference_step(self, generator, models, sample, prefix_tokens=None):
+    def inference_step(
+        self, generator, models, sample, prefix_tokens=None, recurring_punctuation=None
+    ):
         with torch.no_grad():
-            return generator.generate(models, sample, prefix_tokens=prefix_tokens)
+            return generator.generate(
+                models,
+                sample,
+                prefix_tokens=prefix_tokens,
+                recurring_punctuation=recurring_punctuation,
+            )
 
     def begin_epoch(self, epoch, model):
         """Hook function called before the start of each epoch."""
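
Note (not part of the patch): a minimal, runnable sketch of the stop check the
sequence_generator hunk implements — two consecutive single-newline token ids
are treated as the "\n\n" stop sequence. NEWLINE_ID is a hypothetical
stand-in; the patch itself looks the real id up via self.bpe at runtime.

    import torch

    NEWLINE_ID = 42  # hypothetical BPE id for "\n", not metaseq's actual id

    def two_newline_stop(next_toks: torch.Tensor, prev_toks: torch.Tensor) -> torch.Tensor:
        # Per-sequence bool mask: True where "\n" was just emitted twice in a row.
        return (next_toks == NEWLINE_ID) & (prev_toks == NEWLINE_ID)

    prev = torch.tensor([NEWLINE_ID, 17])
    nxt = torch.tensor([NEWLINE_ID, NEWLINE_ID])
    print(two_newline_stop(nxt, prev))  # tensor([ True, False])

Tracking only the previous token keeps the per-step check O(1); stop strings
spanning more than two tokens would need a wider window of prior tokens.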