Commit aedf40a
Fix potential position embedding issue in decoders (#187)
* Resolve #158
gpengzhi authored Sep 5, 2019
1 parent 735598e commit aedf40a
Showing 6 changed files with 89 additions and 46 deletions.
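For context, the failure mode this commit guards against can be reproduced in a few lines. A minimal standalone sketch (not part of the commit; all names are illustrative, not Texar APIs): once `time + 1` equals the number of learned positions, preparing inputs for the next step indexes past the position-embedding table.

import torch
import torch.nn as nn

max_decoding_length = 4
# Position embeddings exist only for positions 0 .. max_decoding_length - 1.
pos_embed = nn.Embedding(max_decoding_length, 8)

time = max_decoding_length - 1  # final decoding step
try:
    # Computing inputs for step `time + 1` would look up position 4.
    pos_embed(torch.tensor([time + 1]))
except IndexError as err:
    print("position embedding lookup failed:", err)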
examples/seq2seq_attn/seq2seq_attn.py (2 changes: 1 addition & 1 deletion)
@@ -99,7 +99,7 @@ def forward(self, batch, mode):
 
         infer_outputs = self.decoder(
             start_tokens=start_tokens,
-            end_token=self.eos_token_id.item(),
+            end_token=self.eos_token_id,
             memory=memory,
             memory_sequence_length=batch['source_length'],
             beam_width=config_model.beam_width)
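The example script now passes the scalar tensor directly instead of converting it at the call site. A quick illustration of what `.item()` was doing (assuming, as this change implies, the decoder also accepts a 0-dim tensor for `end_token`, so the conversion was redundant):

import torch

eos_token_id = torch.tensor(2)
print(eos_token_id.item())        # 2, a plain Python int
print(type(eos_token_id.item()))  # <class 'int'>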
texar/torch/modules/decoders/decoder_base.py (44 changes: 39 additions & 5 deletions)
@@ -27,6 +27,7 @@
 from texar.torch.modules.decoders import decoder_helpers as helpers
 from texar.torch.modules.decoders.decoder_helpers import Helper
 from texar.torch.utils import utils
+from texar.torch.utils.dtypes import torch_bool
 
 __all__ = [
     '_make_output_layer',
@@ -405,8 +406,23 @@ def dynamic_decode(self, helper: Helper, inputs: Optional[torch.Tensor],
 
         while (not torch.all(finished).item() and
                (max_decoding_length is None or time < max_decoding_length)):
-            (next_outputs, decoder_state, next_inputs,
-             decoder_finished) = self.step(helper, time, step_inputs, state)
+
+            next_outputs, decoder_state = \
+                self.step(helper, time, step_inputs, state)
+
+            if max_decoding_length is not None and \
+                    time + 1 == max_decoding_length:
+                # Maximum decoding length reached, mark all batches as finished.
+                # This requires special handling because performing lookup on
+                # position embeddings with `time + 1` may result in IndexError.
+                decoder_finished = torch.tensor(1, dtype=torch_bool,
+                                                device=finished.device)
+                # Since `next_inputs` will not be used, simply create a null
+                # tensor.
+                next_inputs = torch.empty(0)
+            else:
+                next_inputs, decoder_finished = self.next_inputs(
+                    helper, time, next_outputs)
 
             if getattr(self, 'tracks_own_finished', False):
                 next_finished = decoder_finished
@@ -482,8 +498,9 @@ def initialize(self, helper: Helper, inputs: Optional[torch.Tensor],
     @abstractmethod
     def step(self, helper: Helper, time: int,
              inputs: torch.Tensor, state: Optional[State]) \
-            -> Tuple[Output, State, torch.Tensor, torch.ByteTensor]:
-        r"""Called per step of decoding (but only once for dynamic decoding).
+            -> Tuple[Output, State]:
+        r"""Compute the output and the state at the current time step.
+
+        Called per step of decoding (but only once for dynamic decoding).
 
         Args:
             helper: The :class:`~texar.torch.modules.Helper` instance to use.
@@ -492,10 +509,27 @@ def step(self, helper: Helper, time: int,
             state: Decoder state from the previous time step.
 
         Returns:
-            A tuple ``(outputs, next_state, next_inputs, finished)``.
+            A tuple ``(outputs, next_state)``.
+
+            - ``outputs`` is an object containing the decoder output.
+            - ``next_state`` is the decoder state for the next time step.
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def next_inputs(self, helper: Helper, time: int, outputs: Output) -> \
+            Tuple[torch.Tensor, torch.ByteTensor]:
+        r"""Compute the input for the next time step.
+
+        Called per step of decoding (but only once for dynamic decoding).
+
+        Args:
+            helper: The :class:`~texar.torch.modules.Helper` instance to use.
+            time (int): Current step number.
+            outputs: An object containing the decoder output.
+
+        Returns:
+            A tuple ``(next_inputs, finished)``.
+
+            - ``next_inputs`` is the tensor that should be used as input for
+              the next step.
+            - ``finished`` is a :torch:`ByteTensor` tensor telling whether the
+              sequence is complete or not.
+        """
+        raise NotImplementedError
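To make the new contract concrete, here is a hedged sketch of a decoder implementing the two abstract methods above. Everything except the `step`/`next_inputs` shape of the API (`ToyOutput`, `ToyGreedyHelper`, `ToyDecoder`, the GRU cell) is hypothetical scaffolding, not Texar code:

from typing import NamedTuple, Optional, Tuple

import torch
import torch.nn as nn

class ToyOutput(NamedTuple):
    logits: torch.Tensor
    sample_id: torch.LongTensor

class ToyGreedyHelper:
    # Minimal greedy helper: sample the argmax, finish on `end_token`.
    def __init__(self, end_token: int):
        self.end_token = end_token

    def sample(self, time: int, outputs: torch.Tensor) -> torch.LongTensor:
        return outputs.argmax(dim=-1)

    def next_inputs(self, embed_fn, time, logits, sample_ids):
        finished = sample_ids == self.end_token
        return finished, embed_fn(sample_ids)

class ToyDecoder(nn.Module):
    def __init__(self, vocab_size: int = 10, hidden: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.cell = nn.GRUCell(hidden, hidden)
        self.out = nn.Linear(hidden, vocab_size)

    def embed_tokens(self, ids: torch.LongTensor) -> torch.Tensor:
        return self.embed(ids)

    def step(self, helper, time: int, inputs: torch.Tensor,
             state: Optional[torch.Tensor]) -> Tuple[ToyOutput, torch.Tensor]:
        # `step` now only computes the output and the next state.
        cell_out = self.cell(inputs, state)
        logits = self.out(cell_out)
        sample_ids = helper.sample(time=time, outputs=logits)
        return ToyOutput(logits, sample_ids), cell_out

    def next_inputs(self, helper, time: int,
                    outputs: ToyOutput) -> Tuple[torch.Tensor, torch.Tensor]:
        # Input preparation lives here, so the driving loop can skip it
        # on the final step and avoid out-of-range position lookups.
        finished, next_inputs = helper.next_inputs(
            self.embed_tokens, time, outputs.logits, outputs.sample_id)
        return next_inputs, finished

# Drive a single step with the split contract:
decoder = ToyDecoder()
helper = ToyGreedyHelper(end_token=0)
inputs = decoder.embed_tokens(torch.tensor([1]))
outputs, state = decoder.step(helper, time=0, inputs=inputs, state=None)
next_inp, finished = decoder.next_inputs(helper, time=0, outputs=outputs)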
texar/torch/modules/decoders/rnn_decoder_base.py (6 changes: 5 additions & 1 deletion)
@@ -209,7 +209,11 @@ def initialize(self, helper: Helper, inputs: Optional[torch.Tensor],
 
     def step(self, helper: Helper, time: int,
              inputs: torch.Tensor, state: Optional[State]) \
-            -> Tuple[Output, State, torch.Tensor, torch.ByteTensor]:
+            -> Tuple[Output, State]:
+        raise NotImplementedError
+
+    def next_inputs(self, helper: Helper, time: int, outputs: Output) -> \
+            Tuple[torch.Tensor, torch.ByteTensor]:
         raise NotImplementedError
 
     @property
texar/torch/modules/decoders/rnn_decoders.py (36 changes: 22 additions & 14 deletions)
@@ -236,18 +236,22 @@ def default_hparams():
         hparams['name'] = 'basic_rnn_decoder'
         return hparams
 
-    def step(self, helper: Helper, time: int,
-             inputs: torch.Tensor, state: Optional[HiddenState]) \
-            -> Tuple[BasicRNNDecoderOutput, HiddenState,
-                     torch.Tensor, torch.ByteTensor]:
+    def step(self, helper: Helper, time: int, inputs: torch.Tensor,
+             state: Optional[HiddenState]) \
+            -> Tuple[BasicRNNDecoderOutput, HiddenState]:
         cell_outputs, cell_state = self._cell(inputs, state)
         logits = self._output_layer(cell_outputs)
         sample_ids = helper.sample(time=time, outputs=logits)
-        (finished, next_inputs) = helper.next_inputs(
-            self.embed_tokens, time, logits, sample_ids)
         next_state = cell_state
         outputs = BasicRNNDecoderOutput(logits, sample_ids, cell_outputs)
-        return outputs, next_state, next_inputs, finished
+        return outputs, next_state
+
+    def next_inputs(self, helper: Helper, time: int,
+                    outputs: BasicRNNDecoderOutput) -> \
+            Tuple[torch.Tensor, torch.ByteTensor]:
+        finished, next_inputs = helper.next_inputs(
+            self.embed_tokens, time, outputs.logits, outputs.sample_id)
+        return next_inputs, finished
 
     @property
     def output_size(self):
@@ -565,18 +569,15 @@ def initialize(  # type: ignore
 
         return initial_finished, initial_inputs, state
 
-    def step(self, helper: Helper, time: int,
-             inputs: torch.Tensor, state: Optional[AttentionWrapperState]) -> \
-            Tuple[AttentionRNNDecoderOutput, AttentionWrapperState,
-                  torch.Tensor, torch.ByteTensor]:
+    def step(self, helper: Helper, time: int, inputs: torch.Tensor,
+             state: Optional[AttentionWrapperState]) -> \
+            Tuple[AttentionRNNDecoderOutput, AttentionWrapperState]:
         wrapper_outputs, wrapper_state = self._cell(
             inputs, state, self.memory, self.memory_sequence_length)
         # Essentially the same as in BasicRNNDecoder.step()
 
         logits = self._output_layer(wrapper_outputs)
         sample_ids = helper.sample(time=time, outputs=logits)
-        finished, next_inputs = helper.next_inputs(
-            self.embed_tokens, time, logits, sample_ids)
 
         attention_scores = wrapper_state.alignments
         attention_context = wrapper_state.attention
@@ -585,7 +586,14 @@ def step(self, helper: Helper, time: int,
                 attention_scores, attention_context)
         next_state = wrapper_state
 
-        return outputs, next_state, next_inputs, finished
+        return outputs, next_state
+
+    def next_inputs(self, helper: Helper, time: int,
+                    outputs: AttentionRNNDecoderOutput) -> \
+            Tuple[torch.Tensor, torch.ByteTensor]:
+        finished, next_inputs = helper.next_inputs(
+            self.embed_tokens, time, outputs.logits, outputs.sample_id)
+        return next_inputs, finished
 
     def forward(  # type: ignore
             self,
texar/torch/modules/decoders/transformer_decoders.py (27 changes: 11 additions & 16 deletions)
@@ -32,7 +32,6 @@
 from texar.torch.modules.networks.networks import FeedForwardNetwork
 from texar.torch.utils import transformer_attentions as attn
 from texar.torch.utils.beam_search import beam_search
-from texar.torch.utils.dtypes import torch_bool
 from texar.torch.utils.shapes import mask_sequences
 from texar.torch.utils.utils import sequence_mask
 
@@ -731,10 +730,9 @@ def initialize(self, helper: Helper, inputs: Optional[torch.Tensor],
         state = initial_state or self._state_cache
         return initial_finished, initial_inputs, state
 
-    def step(self, helper: Helper, time: int,
-             inputs: torch.Tensor, state: Optional[Cache]) \
-            -> Tuple[TransformerDecoderOutput, Cache,
-                     torch.Tensor, torch.ByteTensor]:
+    def step(self, helper: Helper, time: int, inputs: torch.Tensor,
+             state: Optional[Cache]) -> \
+            Tuple[TransformerDecoderOutput, Cache]:
         assert state is not None
         outputs, state = self._inputs_to_outputs(inputs, state)
         sample_ids = helper.sample(time=time, outputs=outputs)
@@ -745,21 +743,18 @@ def step(self, helper: Helper, time: int,
             self._state_context[:, time],
             sample_ids)
 
-        if time + 1 == self._state_max_decoding_length:
-            # Maximum decoding length reached, mark all batches as finished.
-            # This requires special handling because performing lookup on
-            # position embeddings with `time + 1` may result in IndexError.
-            finished = torch.ones_like(sample_ids, dtype=torch_bool)
-            # Since `next_inputs` will not be used, simply create a null tensor.
-            next_inputs = torch.empty(0)
-        else:
-            finished, next_inputs = helper.next_inputs(
-                self.embed_tokens, time, outputs, sample_ids)
         next_state = state
         outputs = TransformerDecoderOutput(
             logits=outputs,
             sample_id=sample_ids)
-        return outputs, next_state, next_inputs, finished
+        return outputs, next_state
+
+    def next_inputs(self, helper: Helper, time: int,
+                    outputs: TransformerDecoderOutput) -> \
+            Tuple[torch.Tensor, torch.ByteTensor]:
+        finished, next_inputs = helper.next_inputs(
+            self.embed_tokens, time, outputs.logits, outputs.sample_id)
+        return next_inputs, finished
 
     def finalize(self,  # type: ignore
                  outputs: TransformerDecoderOutput,
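The block removed above is the per-decoder version of the guard that this commit hoists into the base class (see decoder_base.py at the top of this diff). A compressed, standalone sketch of the relocated control flow, simplified from the diff; the stub decoder below is a stand-in, not the library class:

import torch

class StubDecoder:
    def step(self, time, inputs, state):
        # Pretend logits for a batch of 2; real decoders run a cell or
        # transformer block here.
        return torch.randn(2, 5), state

    def next_inputs(self, time, outputs):
        sample_ids = outputs.argmax(dim=-1)
        return sample_ids.float().unsqueeze(-1), (sample_ids == 0)

def decode_loop(decoder, step_inputs, state, max_decoding_length):
    finished = torch.zeros(2, dtype=torch.bool)
    time = 0
    while not torch.all(finished).item() and time < max_decoding_length:
        outputs, state = decoder.step(time, step_inputs, state)
        if time + 1 == max_decoding_length:
            # Final step: mark everything finished instead of computing
            # inputs (and position embeddings) for a step that never runs.
            finished = torch.ones_like(finished)
            step_inputs = torch.empty(0)  # never used
        else:
            step_inputs, finished = decoder.next_inputs(time, outputs)
        time += 1
    return time

print(decode_loop(StubDecoder(), torch.randn(2, 1), None, max_decoding_length=4))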
texar/torch/modules/decoders/xlnet_decoder.py (20 changes: 11 additions & 9 deletions)
@@ -257,12 +257,9 @@ def initialize(self,
             self.embed_tokens, inputs, sequence_length)
         return initial_finished, initial_inputs, initial_state
 
-    def step(self,
-             helper: Helper,
-             time: int,
-             inputs: torch.Tensor,
-             state: Optional[State]) \
-            -> Tuple[Output, Optional[State], torch.Tensor, torch.ByteTensor]:
+    def step(self, helper: Helper, time: int, inputs: torch.Tensor,
+             state: Optional[State]) -> \
+            Tuple[Output, Optional[State]]:
         self._state_previous_inputs.append(inputs)
         if self._state_recompute_memory:
             net_output, memory = self._forward(
@@ -281,10 +278,15 @@ def step(self,
         logits = F.linear(net_output, self.word_embed.weight, self.lm_bias)
         logits = logits[:, -1]
         sample_ids = helper.sample(time=time, outputs=logits)
-        (finished, next_inputs) = helper.next_inputs(
-            self.embed_tokens, time, logits, sample_ids)
         outputs = XLNetDecoderOutput(logits=logits, sample_id=sample_ids)
-        return outputs, memory, next_inputs, finished
+        return outputs, memory
+
+    def next_inputs(self, helper: Helper, time: int,
+                    outputs: Output) -> \
+            Tuple[torch.Tensor, torch.ByteTensor]:
+        finished, next_inputs = helper.next_inputs(
+            self.embed_tokens, time, outputs.logits, outputs.sample_id)
+        return next_inputs, finished
 
     def finalize(self, outputs, final_state, sequence_lengths):
         del self._state_cache_len
