Multilayer decoder for semantic parsing framework (allenai#1902)
* first pass

* clean up

* rename

* keep final states false

* fix test

* use lstm cell if only 1 layer

* comment

* remove commented out code

* lint, make model compatible with 1 layer decoder

* test multi layers
kl2806 authored Oct 13, 2018
1 parent 7a707ea commit 3753b0b
Showing 5 changed files with 72 additions and 25 deletions.
26 changes: 19 additions & 7 deletions allennlp/models/semantic_parsing/atis/atis_semantic_parser.py
@@ -69,6 +69,7 @@ def __init__(self,
input_attention: Attention,
add_action_bias: bool = True,
training_beam_size: int = None,
decoder_num_layers: int = 1,
dropout: float = 0.0,
rule_namespace: str = 'rule_labels',
database_file='/atis/atis.db') -> None:
@@ -108,6 +109,7 @@ def __init__(self,

self._num_entity_types = 2 # TODO(kevin): get this in a more principled way somehow?
self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)
self._decoder_num_layers = decoder_num_layers

self._beam_search = decoder_beam_search
self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
@@ -116,7 +118,8 @@
input_attention=input_attention,
predict_start_type_separately=False,
add_action_bias=self._add_action_bias,
dropout=dropout)
dropout=dropout,
num_layers=self._decoder_num_layers)

@overrides
def forward(self, # type: ignore
@@ -278,12 +281,21 @@ def _get_initial_state(self,
utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
initial_rnn_state = []
for i in range(batch_size):
initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
memory_cell[i],
self._first_action_embedding,
self._first_attended_utterance,
encoder_output_list,
utterance_mask_list))
if self._decoder_num_layers > 1:
initial_rnn_state.append(RnnStatelet(final_encoder_output[i].repeat(self._decoder_num_layers, 1),
memory_cell[i].repeat(self._decoder_num_layers, 1),
self._first_action_embedding,
self._first_attended_utterance,
encoder_output_list,
utterance_mask_list))
else:
initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
memory_cell[i],
self._first_action_embedding,
self._first_attended_utterance,
encoder_output_list,
utterance_mask_list))


initial_grammar_state = [self._create_grammar_state(worlds[i],
actions[i],
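
The branch added above is the model-side core of the change: with a stacked decoder, each instance's initial hidden state and memory cell must have shape (num_layers, decoder_output_dim), so the final encoder output and memory cell are repeated across the layer dimension. A minimal standalone sketch of that shape handling, using illustrative sizes that are not taken from the PR:

import torch

num_layers, decoder_output_dim = 2, 10  # illustrative sizes, not from the PR
final_encoder_output = torch.zeros(decoder_output_dim)
memory_cell = torch.zeros(decoder_output_dim)

# Single layer: states stay 1-D with shape (decoder_output_dim,).
# Stacked decoder: repeat across layers to get (num_layers, decoder_output_dim).
hidden_state = final_encoder_output.repeat(num_layers, 1)
memory_cell = memory_cell.repeat(num_layers, 1)
assert hidden_state.shape == (num_layers, decoder_output_dim)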
6 changes: 4 additions & 2 deletions allennlp/state_machines/states/rnn_statelet.py
@@ -17,9 +17,11 @@ class RnnStatelet:
Parameters
----------
hidden_state : ``torch.Tensor``
This holds the LSTM hidden state, with shape ``(decoder_output_dim,)``.
This holds the LSTM hidden state, with shape ``(decoder_output_dim,)`` if the decoder
has 1 layer and ``(num_layers, decoder_output_dim)`` otherwise.
memory_cell : ``torch.Tensor``
This holds the LSTM memory cell, with shape ``(decoder_output_dim,)``.
This holds the LSTM memory cell, with shape ``(decoder_output_dim,)`` if the decoder has
1 layer and ``(num_layers, decoder_output_dim)`` otherwise.
previous_action_embedding : ``torch.Tensor``
This holds the embedding for the action we took at the last timestep (which gets input to
the decoder). Has shape ``(action_embedding_dim,)``.
allennlp/state_machines/transition_functions/basic_transition_function.py
@@ -4,7 +4,7 @@
from overrides import overrides

import torch
from torch.nn.modules.rnn import LSTMCell
from torch.nn.modules.rnn import LSTM, LSTMCell
from torch.nn.modules.linear import Linear

from allennlp.modules import Attention
@@ -47,6 +47,8 @@ class BasicTransitionFunction(TransitionFunction[GrammarBasedState]):
gets used when predicting the next action. We add a dimension of ones to our predicted
action vector in this case to account for that.
dropout : ``float`` (optional, default=0.0)
num_layers : ``int`` (optional, default=1)
The number of layers in the decoder LSTM.
"""
def __init__(self,
encoder_output_dim: int,
@@ -56,11 +58,13 @@ def __init__(self,
predict_start_type_separately: bool = True,
num_start_types: int = None,
add_action_bias: bool = True,
dropout: float = 0.0) -> None:
dropout: float = 0.0,
num_layers: int = 1) -> None:
super().__init__()
self._input_attention = input_attention
self._add_action_bias = add_action_bias
self._activation = activation
self._num_layers = num_layers

self._predict_start_type_separately = predict_start_type_separately
if predict_start_type_separately:
@@ -82,8 +86,12 @@ def __init__(self,
# hidden state. Then we concatenate those with the decoder state and project to
# `action_embedding_dim` to make a prediction.
self._output_projection_layer = Linear(output_dim + encoder_output_dim, action_embedding_dim)

self._decoder_cell = LSTMCell(input_dim, output_dim)
if self._num_layers > 1:
self._decoder_cell = LSTM(input_dim, output_dim, self._num_layers)
else:
# We use an ``LSTMCell`` when there is just one layer because it is slightly faster,
# since we only run the decoder for one step at a time.
self._decoder_cell = LSTMCell(input_dim, output_dim)

if dropout > 0:
self._dropout = torch.nn.Dropout(p=dropout)
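
The constructor branch above keeps the cheaper LSTMCell for the common single-layer case and only builds a full LSTM when the decoder is genuinely stacked; the two modules also carry their states differently, which is what the later shape handling accounts for. A small sketch of the two state layouts, with illustrative sizes (running a single step through the stacked LSTM is sketched after the _update_decoder_state hunk below):

import torch
from torch.nn import LSTM, LSTMCell

input_dim, output_dim, num_layers, batch = 6, 8, 2, 4  # illustrative sizes

# LSTMCell advances one timestep directly from 2-D states of shape (batch, output_dim).
cell = LSTMCell(input_dim, output_dim)
h, c = cell(torch.zeros(batch, input_dim),
            (torch.zeros(batch, output_dim), torch.zeros(batch, output_dim)))

# A stacked LSTM instead expects a (seq_len, batch, input_dim) input and 3-D states
# of shape (num_layers, batch, output_dim).
lstm = LSTM(input_dim, output_dim, num_layers)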
@@ -128,26 +136,41 @@ def _update_decoder_state(self, state: GrammarBasedState) -> Dict[str, torch.Ten

group_size = len(state.batch_indices)
attended_question = torch.stack([rnn_state.attended_input for rnn_state in state.rnn_state])
hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state])
memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state])
if self._num_layers > 1:
hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state], 1)
memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state], 1)
else:
hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state])
memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state])

previous_action_embedding = torch.stack([rnn_state.previous_action_embedding
for rnn_state in state.rnn_state])

# (group_size, decoder_input_dim)
projected_input = self._input_projection_layer(torch.cat([attended_question,
previous_action_embedding], -1))
decoder_input = self._activation(projected_input)

hidden_state, memory_cell = self._decoder_cell(decoder_input, (hidden_state, memory_cell))
if self._num_layers > 1:
_, (hidden_state, memory_cell) = self._decoder_cell(decoder_input.unsqueeze(0),
(hidden_state, memory_cell))
else:
hidden_state, memory_cell = self._decoder_cell(decoder_input, (hidden_state, memory_cell))
hidden_state = self._dropout(hidden_state)

# (group_size, encoder_output_dim)
encoder_outputs = torch.stack([state.rnn_state[0].encoder_outputs[i] for i in state.batch_indices])
encoder_output_mask = torch.stack([state.rnn_state[0].encoder_output_mask[i] for i in state.batch_indices])
attended_question, attention_weights = self.attend_on_question(hidden_state,
encoder_outputs,
encoder_output_mask)
action_query = torch.cat([hidden_state, attended_question], dim=-1)

if self._num_layers > 1:
attended_question, attention_weights = self.attend_on_question(hidden_state[-1],
encoder_outputs,
encoder_output_mask)
action_query = torch.cat([hidden_state[-1], attended_question], dim=-1)
else:
attended_question, attention_weights = self.attend_on_question(hidden_state,
encoder_outputs,
encoder_output_mask)
action_query = torch.cat([hidden_state, attended_question], dim=-1)

# (group_size, action_embedding_dim)
projected_query = self._activation(self._output_projection_layer(action_query))
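
In the multi-layer path above, a single decoding step is phrased as a length-one sequence (decoder_input.unsqueeze(0)), and only the top layer's hidden state (hidden_state[-1]) feeds the attention over the question and the output projection. A hedged, self-contained sketch of that one-step pattern, again with illustrative sizes:

import torch
from torch.nn import LSTM

num_layers, group_size, input_dim, output_dim = 2, 3, 6, 8  # illustrative sizes
decoder_cell = LSTM(input_dim, output_dim, num_layers)

decoder_input = torch.zeros(group_size, input_dim)
hidden_state = torch.zeros(num_layers, group_size, output_dim)
memory_cell = torch.zeros(num_layers, group_size, output_dim)

# Add a seq_len dimension of 1 so the stacked LSTM runs exactly one timestep.
_, (hidden_state, memory_cell) = decoder_cell(decoder_input.unsqueeze(0),
                                              (hidden_state, memory_cell))

# Only the last layer's hidden state drives attention and the action query.
top_hidden = hidden_state[-1]  # (group_size, output_dim)
assert top_hidden.shape == (group_size, output_dim)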
@@ -227,8 +250,13 @@ def _construct_next_states(self,
# each time is more expensive than doing it once upfront. These three lines give about a
# 10% speedup in training time.
group_size = len(state.batch_indices)
hidden_state = [x.squeeze(0) for x in updated_rnn_state['hidden_state'].chunk(group_size, 0)]
memory_cell = [x.squeeze(0) for x in updated_rnn_state['memory_cell'].chunk(group_size, 0)]

chunk_index = 1 if self._num_layers > 1 else 0
hidden_state = [x.squeeze(chunk_index)
for x in updated_rnn_state['hidden_state'].chunk(group_size, chunk_index)]
memory_cell = [x.squeeze(chunk_index)
for x in updated_rnn_state['memory_cell'].chunk(group_size, chunk_index)]

attended_question = [x.squeeze(0) for x in updated_rnn_state['attended_question'].chunk(group_size, 0)]

def make_state(group_index: int,
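
Because a stacked decoder's grouped states have shape (num_layers, group_size, output_dim), splitting them back into per-instance statelets has to chunk and squeeze along dimension 1 rather than 0, which is what the chunk_index switch above does. A tiny illustration with made-up sizes:

import torch

num_layers, group_size, output_dim = 2, 3, 8  # illustrative sizes
hidden_state = torch.zeros(num_layers, group_size, output_dim)

chunk_index = 1  # would be 0 for a single-layer state of shape (group_size, output_dim)
per_instance = [x.squeeze(chunk_index)
                for x in hidden_state.chunk(group_size, chunk_index)]
assert per_instance[0].shape == (num_layers, output_dim)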
allennlp/state_machines/transition_functions/linking_transition_function.py
@@ -51,6 +51,8 @@ class LinkingTransitionFunction(BasicTransitionFunction):
actions given the hidden state at every timestep of decoding, instead of concatenating the
logits for both (where the logits may not be compatible with each other).
dropout : ``float`` (optional, default=0.0)
num_layers : ``int`` (optional, default=1)
The number of layers in the decoder LSTM.
"""
def __init__(self,
encoder_output_dim: int,
@@ -61,15 +63,17 @@ def __init__(self,
num_start_types: int = None,
add_action_bias: bool = True,
mixture_feedforward: FeedForward = None,
dropout: float = 0.0) -> None:
dropout: float = 0.0,
num_layers: int = 1) -> None:
super().__init__(encoder_output_dim=encoder_output_dim,
action_embedding_dim=action_embedding_dim,
input_attention=input_attention,
num_start_types=num_start_types,
activation=activation,
predict_start_type_separately=predict_start_type_separately,
add_action_bias=add_action_bias,
dropout=dropout)
dropout=dropout,
num_layers=num_layers)
self._mixture_feedforward = mixture_feedforward

if mixture_feedforward is not None:
@@ -30,6 +30,7 @@
"decoder_beam_search": {
"beam_size": 5
},
"decoder_num_layers": 2,
"max_decoding_steps": 10,
"input_attention": {"type": "dot_product"},
"dropout": 0.5,
