Generation should stop after two newlines if that is the stop criterion
Summary: This addresses Issue 642. When the stop token is \n\n, generation
should stop after two consecutive newline tokens have been generated.
Vidyaranya committed May 2, 2023
1 parent edefd4a commit d704750
Showing 3 changed files with 38 additions and 4 deletions.
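
Before the file-by-file diff, a minimal standalone sketch of the problem being fixed (plain Python, not metaseq code; the token ids are GPT-2's and should be verified against your tokenizer). A byte-level BPE gives "\n\n" its own merged id, but a model sampling token by token can just as easily emit the single-"\n" id twice in a row; that decodes to the same text yet never matches the merged stop id, so the generator must also watch for two consecutive single-newline tokens:

single_newline = 198  # GPT-2 id for "\n" (illustrative; check your tokenizer)
double_newline = 628  # GPT-2 id for "\n\n"

# A sampled continuation that ends in two separate "\n" tokens.
generated = [464, 3290, 13, 198, 198]

# The naive check on the merged id never fires:
assert double_newline not in generated

# The rule this commit adds: two consecutive single-newline tokens
# also satisfy the "\n\n" stop criterion.
stopped = any(
    a == single_newline and b == single_newline
    for a, b in zip(generated, generated[1:])
)
assert stopped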
20 changes: 18 additions & 2 deletions metaseq/hub_utils.py
@@ -82,6 +82,15 @@ def decode(self, sentence: str) -> str:
         return self.tokenizer.decode(sentence)
 
 
+class RecurringPunctuation(object):
+    """Class for grouping tokens of similar type. For example \n and \n\n"""
+
+    def __init__(self, single_token, multiple_token):
+        super().__init__()
+        self.single_token = single_token
+        self.multiple_token = multiple_token
+
+
 class GeneratorInterface:
     """
     PyTorch Hub interface for generating sequences from a pre-trained
@@ -323,14 +332,21 @@ def generate(
             self.cfg.generation,
             extra_gen_cls_kwargs=extra_gen_cls_kwargs,
         )
 
         # okay actually generate
         logger.info(f"Executing generation on input tensor size {src_tokens.shape}")
         if use_cuda:
             batch = utils.move_to_cuda(batch)
 
         translate_start_time = time.time()
-        translations = self.task.inference_step(generator, self.models, batch)
+        recurring_punctuation = RecurringPunctuation(
+            self.bpe.bpe.encode("\n").ids[0], self.bpe.bpe.encode("\n\n").ids[0]
+        )
+        translations = self.task.inference_step(
+            generator,
+            self.models,
+            batch,
+            recurring_punctuation=recurring_punctuation,
+        )
         translate_time = time.time() - translate_start_time
         total_generation_time += translate_time
11 changes: 11 additions & 0 deletions metaseq/sequence_generator.py
@@ -19,6 +19,7 @@
 from metaseq import utils
 from metaseq.data import data_utils
 from metaseq.models import BaseDecoder
+from metaseq.hub_utils import RecurringPunctuation
 
 logger = logging.getLogger(__name__)

@@ -130,6 +131,7 @@ def forward(
         sample: Dict[str, Dict[str, Tensor]],
         prefix_tokens: Optional[Tensor] = None,
         bos_token: Optional[int] = None,
+        recurring_punctuation: Optional[RecurringPunctuation] = None,
     ):
         """Generate a batch of translations."""
-        return self._generate(sample, prefix_tokens, bos_token=bos_token)
+        return self._generate(
+            sample, prefix_tokens, bos_token=bos_token, recurring_punctuation=recurring_punctuation
+        )
@@ -144,6 +146,7 @@ def _generate(
         sample: Dict[str, Dict[str, Tensor]],
         prefix_tokens: Optional[Tensor] = None,
         bos_token: Optional[int] = None,
+        recurring_punctuation: Optional[RecurringPunctuation] = None,
     ):
         """
         Args:
@@ -268,6 +271,7 @@ def _generate(
 
         eos_mask = torch.zeros(lprobs.size(0), dtype=torch.bool, device=lprobs.device)
 
+        prev_token = None
         for step in range(start_step, max_len):
             if step < min_len:
                 # minimum length constraint (does not apply if using prefix_tokens)
@@ -303,13 +307,20 @@ def _generate(
             all_lprobs[:, step] = lprobs
 
             eos_mask |= next_toks == self.eos
 
             for stop_token in self.stop:
                 # if there are other early stopping tokens, allow those to trigger stop
                 eos_mask |= next_toks == stop_token
+                if (
+                    recurring_punctuation is not None
+                    and stop_token == recurring_punctuation.multiple_token
+                    and prev_token is not None
+                ):
+                    # two consecutive single tokens ("\n" then "\n") are
+                    # equivalent to generating the multiple token ("\n\n")
+                    eos_mask |= (next_toks == recurring_punctuation.single_token) & (
+                        prev_token == recurring_punctuation.single_token
+                    )
 
             if torch.all(eos_mask):
                 break
 
+            prev_token = next_toks
             # forward through the next pass
             model_out = self.model.decoder(
                 tokens[:, : step + 1],
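
To see the corrected mask arithmetic in isolation, a small sketch with illustrative ids and a three-beam batch (not metaseq's shapes or API):

import torch

single, double = 198, 628  # illustrative "\n" and "\n\n" ids
stop_tokens = [double]

next_toks = torch.tensor([198, 5, 198])   # tokens chosen at this step, per beam
prev_token = torch.tensor([198, 198, 7])  # tokens chosen at the previous step
eos_mask = torch.zeros(3, dtype=torch.bool)

for stop_token in stop_tokens:
    eos_mask |= next_toks == stop_token
    if stop_token == double:
        # two consecutive single newlines count as the "\n\n" stop
        eos_mask |= (next_toks == single) & (prev_token == single)

print(eos_mask)  # tensor([ True, False, False]): only beam 0 emitted "\n" twice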
11 changes: 9 additions & 2 deletions metaseq/tasks/base_task.py
@@ -420,9 +420,16 @@ def build_dataset_for_inference(
     ) -> torch.utils.data.Dataset:
         raise NotImplementedError
 
-    def inference_step(self, generator, models, sample, prefix_tokens=None):
+    def inference_step(
+        self, generator, models, sample, prefix_tokens=None, recurring_punctuation=None
+    ):
         with torch.no_grad():
-            return generator.generate(models, sample, prefix_tokens=prefix_tokens)
+            return generator.generate(
+                models,
+                sample,
+                prefix_tokens=prefix_tokens,
+                recurring_punctuation=recurring_punctuation,
+            )
 
     def begin_epoch(self, epoch, model):
         """Hook function called before the start of each epoch."""