From aedf40a55dbe324840186e870fc0e414bb94d0c3 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Thu, 5 Sep 2019 15:01:55 -0400 Subject: [PATCH] Fix potential position embedding issue in decoders (#187) * Resolve #158 --- examples/seq2seq_attn/seq2seq_attn.py | 2 +- texar/torch/modules/decoders/decoder_base.py | 44 ++++++++++++++++--- .../modules/decoders/rnn_decoder_base.py | 6 ++- texar/torch/modules/decoders/rnn_decoders.py | 36 +++++++++------ .../modules/decoders/transformer_decoders.py | 27 +++++------- texar/torch/modules/decoders/xlnet_decoder.py | 20 +++++---- 6 files changed, 89 insertions(+), 46 deletions(-) diff --git a/examples/seq2seq_attn/seq2seq_attn.py b/examples/seq2seq_attn/seq2seq_attn.py index 2c37fb3da..e2b688a3a 100644 --- a/examples/seq2seq_attn/seq2seq_attn.py +++ b/examples/seq2seq_attn/seq2seq_attn.py @@ -99,7 +99,7 @@ def forward(self, batch, mode): infer_outputs = self.decoder( start_tokens=start_tokens, - end_token=self.eos_token_id.item(), + end_token=self.eos_token_id, memory=memory, memory_sequence_length=batch['source_length'], beam_width=config_model.beam_width) diff --git a/texar/torch/modules/decoders/decoder_base.py b/texar/torch/modules/decoders/decoder_base.py index 551001278..0a316bb3d 100644 --- a/texar/torch/modules/decoders/decoder_base.py +++ b/texar/torch/modules/decoders/decoder_base.py @@ -27,6 +27,7 @@ from texar.torch.modules.decoders import decoder_helpers as helpers from texar.torch.modules.decoders.decoder_helpers import Helper from texar.torch.utils import utils +from texar.torch.utils.dtypes import torch_bool __all__ = [ '_make_output_layer', @@ -405,8 +406,23 @@ def dynamic_decode(self, helper: Helper, inputs: Optional[torch.Tensor], while (not torch.all(finished).item() and (max_decoding_length is None or time < max_decoding_length)): - (next_outputs, decoder_state, next_inputs, - decoder_finished) = self.step(helper, time, step_inputs, state) + + next_outputs, decoder_state = \ + self.step(helper, time, step_inputs, state) + + if max_decoding_length is not None and \ + time + 1 == max_decoding_length: + # Maximum decoding length reached, mark all batches as finished. + # This requires special handling because performing lookup on + # position embeddings with `time + 1` may result in IndexError. + decoder_finished = torch.tensor(1, dtype=torch_bool, + device=finished.device) + # Since `next_inputs` will not be used, simply create a null + # tensor. + next_inputs = torch.empty(0) + else: + next_inputs, decoder_finished = self.next_inputs( + helper, time, next_outputs) if getattr(self, 'tracks_own_finished', False): next_finished = decoder_finished @@ -482,8 +498,9 @@ def initialize(self, helper: Helper, inputs: Optional[torch.Tensor], @abstractmethod def step(self, helper: Helper, time: int, inputs: torch.Tensor, state: Optional[State]) \ - -> Tuple[Output, State, torch.Tensor, torch.ByteTensor]: - r"""Called per step of decoding (but only once for dynamic decoding). + -> Tuple[Output, State]: + r"""Compute the output and the state at the current time step. + Called per step of decoding (but only once for dynamic decoding). Args: helper: The :class:`~texar.torch.modules.Helper` instance to use. @@ -492,10 +509,27 @@ def step(self, helper: Helper, time: int, state: Decoder state from the previous time step. Returns: - A tuple ``(outputs, next_state, next_inputs, finished)``. + A tuple ``(outputs, next_state)``. - ``outputs`` is an object containing the decoder output. - ``next_state`` is the decoder state for the next time step. + """ + raise NotImplementedError + + @abstractmethod + def next_inputs(self, helper: Helper, time: int, outputs: Output) -> \ + Tuple[torch.Tensor, torch.ByteTensor]: + r"""Compute the input for the next time step. + Called per step of decoding (but only once for dynamic decoding). + + Args: + helper: The :class:`~texar.torch.modules.Helper` instance to use. + time (int): Current step number. + outputs: An object containing the decoder output. + + Returns: + A tuple ``(next_inputs, finished)``. + - ``next_inputs`` is the tensor that should be used as input for the next step. - ``finished`` is a :torch:`ByteTensor` tensor telling whether the diff --git a/texar/torch/modules/decoders/rnn_decoder_base.py b/texar/torch/modules/decoders/rnn_decoder_base.py index 19f73df4c..3784ee930 100644 --- a/texar/torch/modules/decoders/rnn_decoder_base.py +++ b/texar/torch/modules/decoders/rnn_decoder_base.py @@ -209,7 +209,11 @@ def initialize(self, helper: Helper, inputs: Optional[torch.Tensor], def step(self, helper: Helper, time: int, inputs: torch.Tensor, state: Optional[State]) \ - -> Tuple[Output, State, torch.Tensor, torch.ByteTensor]: + -> Tuple[Output, State]: + raise NotImplementedError + + def next_inputs(self, helper: Helper, time: int, outputs: Output) -> \ + Tuple[torch.Tensor, torch.ByteTensor]: raise NotImplementedError @property diff --git a/texar/torch/modules/decoders/rnn_decoders.py b/texar/torch/modules/decoders/rnn_decoders.py index bdc95a265..ba0f8410b 100644 --- a/texar/torch/modules/decoders/rnn_decoders.py +++ b/texar/torch/modules/decoders/rnn_decoders.py @@ -236,18 +236,22 @@ def default_hparams(): hparams['name'] = 'basic_rnn_decoder' return hparams - def step(self, helper: Helper, time: int, - inputs: torch.Tensor, state: Optional[HiddenState]) \ - -> Tuple[BasicRNNDecoderOutput, HiddenState, - torch.Tensor, torch.ByteTensor]: + def step(self, helper: Helper, time: int, inputs: torch.Tensor, + state: Optional[HiddenState]) \ + -> Tuple[BasicRNNDecoderOutput, HiddenState]: cell_outputs, cell_state = self._cell(inputs, state) logits = self._output_layer(cell_outputs) sample_ids = helper.sample(time=time, outputs=logits) - (finished, next_inputs) = helper.next_inputs( - self.embed_tokens, time, logits, sample_ids) next_state = cell_state outputs = BasicRNNDecoderOutput(logits, sample_ids, cell_outputs) - return outputs, next_state, next_inputs, finished + return outputs, next_state + + def next_inputs(self, helper: Helper, time: int, + outputs: BasicRNNDecoderOutput) -> \ + Tuple[torch.Tensor, torch.ByteTensor]: + finished, next_inputs = helper.next_inputs( + self.embed_tokens, time, outputs.logits, outputs.sample_id) + return next_inputs, finished @property def output_size(self): @@ -565,18 +569,15 @@ def initialize( # type: ignore return initial_finished, initial_inputs, state - def step(self, helper: Helper, time: int, - inputs: torch.Tensor, state: Optional[AttentionWrapperState]) -> \ - Tuple[AttentionRNNDecoderOutput, AttentionWrapperState, - torch.Tensor, torch.ByteTensor]: + def step(self, helper: Helper, time: int, inputs: torch.Tensor, + state: Optional[AttentionWrapperState]) -> \ + Tuple[AttentionRNNDecoderOutput, AttentionWrapperState]: wrapper_outputs, wrapper_state = self._cell( inputs, state, self.memory, self.memory_sequence_length) # Essentially the same as in BasicRNNDecoder.step() logits = self._output_layer(wrapper_outputs) sample_ids = helper.sample(time=time, outputs=logits) - finished, next_inputs = helper.next_inputs( - self.embed_tokens, time, logits, sample_ids) attention_scores = wrapper_state.alignments attention_context = wrapper_state.attention @@ -585,7 +586,14 @@ def step(self, helper: Helper, time: int, attention_scores, attention_context) next_state = wrapper_state - return outputs, next_state, next_inputs, finished + return outputs, next_state + + def next_inputs(self, helper: Helper, time: int, + outputs: AttentionRNNDecoderOutput) -> \ + Tuple[torch.Tensor, torch.ByteTensor]: + finished, next_inputs = helper.next_inputs( + self.embed_tokens, time, outputs.logits, outputs.sample_id) + return next_inputs, finished def forward( # type: ignore self, diff --git a/texar/torch/modules/decoders/transformer_decoders.py b/texar/torch/modules/decoders/transformer_decoders.py index 3225bc7ba..5238710b7 100644 --- a/texar/torch/modules/decoders/transformer_decoders.py +++ b/texar/torch/modules/decoders/transformer_decoders.py @@ -32,7 +32,6 @@ from texar.torch.modules.networks.networks import FeedForwardNetwork from texar.torch.utils import transformer_attentions as attn from texar.torch.utils.beam_search import beam_search -from texar.torch.utils.dtypes import torch_bool from texar.torch.utils.shapes import mask_sequences from texar.torch.utils.utils import sequence_mask @@ -731,10 +730,9 @@ def initialize(self, helper: Helper, inputs: Optional[torch.Tensor], state = initial_state or self._state_cache return initial_finished, initial_inputs, state - def step(self, helper: Helper, time: int, - inputs: torch.Tensor, state: Optional[Cache]) \ - -> Tuple[TransformerDecoderOutput, Cache, - torch.Tensor, torch.ByteTensor]: + def step(self, helper: Helper, time: int, inputs: torch.Tensor, + state: Optional[Cache]) -> \ + Tuple[TransformerDecoderOutput, Cache]: assert state is not None outputs, state = self._inputs_to_outputs(inputs, state) sample_ids = helper.sample(time=time, outputs=outputs) @@ -745,21 +743,18 @@ def step(self, helper: Helper, time: int, self._state_context[:, time], sample_ids) - if time + 1 == self._state_max_decoding_length: - # Maximum decoding length reached, mark all batches as finished. - # This requires special handling because performing lookup on - # position embeddings with `time + 1` may result in IndexError. - finished = torch.ones_like(sample_ids, dtype=torch_bool) - # Since `next_inputs` will not be used, simply create a null tensor. - next_inputs = torch.empty(0) - else: - finished, next_inputs = helper.next_inputs( - self.embed_tokens, time, outputs, sample_ids) next_state = state outputs = TransformerDecoderOutput( logits=outputs, sample_id=sample_ids) - return outputs, next_state, next_inputs, finished + return outputs, next_state + + def next_inputs(self, helper: Helper, time: int, + outputs: TransformerDecoderOutput) -> \ + Tuple[torch.Tensor, torch.ByteTensor]: + finished, next_inputs = helper.next_inputs( + self.embed_tokens, time, outputs.logits, outputs.sample_id) + return next_inputs, finished def finalize(self, # type: ignore outputs: TransformerDecoderOutput, diff --git a/texar/torch/modules/decoders/xlnet_decoder.py b/texar/torch/modules/decoders/xlnet_decoder.py index 11ba31007..e59035013 100644 --- a/texar/torch/modules/decoders/xlnet_decoder.py +++ b/texar/torch/modules/decoders/xlnet_decoder.py @@ -257,12 +257,9 @@ def initialize(self, self.embed_tokens, inputs, sequence_length) return initial_finished, initial_inputs, initial_state - def step(self, - helper: Helper, - time: int, - inputs: torch.Tensor, - state: Optional[State]) \ - -> Tuple[Output, Optional[State], torch.Tensor, torch.ByteTensor]: + def step(self, helper: Helper, time: int, inputs: torch.Tensor, + state: Optional[State]) -> \ + Tuple[Output, Optional[State]]: self._state_previous_inputs.append(inputs) if self._state_recompute_memory: net_output, memory = self._forward( @@ -281,10 +278,15 @@ def step(self, logits = F.linear(net_output, self.word_embed.weight, self.lm_bias) logits = logits[:, -1] sample_ids = helper.sample(time=time, outputs=logits) - (finished, next_inputs) = helper.next_inputs( - self.embed_tokens, time, logits, sample_ids) outputs = XLNetDecoderOutput(logits=logits, sample_id=sample_ids) - return outputs, memory, next_inputs, finished + return outputs, memory + + def next_inputs(self, helper: Helper, time: int, + outputs: Output) -> \ + Tuple[torch.Tensor, torch.ByteTensor]: + finished, next_inputs = helper.next_inputs( + self.embed_tokens, time, outputs.logits, outputs.sample_id) + return next_inputs, finished def finalize(self, outputs, final_state, sequence_lengths): del self._state_cache_len