diff --git a/distiller/modules/rnn.py b/distiller/modules/rnn.py
index 7471f32fd..633c088f4 100644
--- a/distiller/modules/rnn.py
+++ b/distiller/modules/rnn.py
@@ -22,6 +22,7 @@
 
 __all__ = ['DistillerLSTMCell', 'DistillerLSTM']
 
+
 class DistillerLSTMCell(nn.Module):
     """
     A single LSTM block.
@@ -231,9 +232,9 @@ def __init__(self, input_size, hidden_size, num_layers, bias=True, batch_first=F
     def forward(self, x, h=None):
         is_packed_seq = isinstance(x, nn.utils.rnn.PackedSequence)
         if is_packed_seq:
-            x, lengths = nn.utils.rnn.pad_packed_sequence(x, self.batch_first)
+            return self.packed_sequence_forward(x, h)
 
-        elif self.batch_first:
+        if self.batch_first:
             # Transpose to sequence_first format
             x = x.transpose(0, 1)
         x_bsz = x.size(1)
@@ -242,14 +243,40 @@ def forward(self, x, h=None):
             h = self.init_hidden(x_bsz)
 
         y, h = self.forward_fn(x, h)
 
-        if is_packed_seq:
-            y = nn.utils.rnn.pack_padded_sequence(y, lengths, self.batch_first)
-        elif self.batch_first:
+        if self.batch_first:
             # Transpose back to batch_first format
             y = y.transpose(0, 1)
         return y, h
 
+    def packed_sequence_forward(self, x, h=None):
+        # Packed sequence handling:
+        # the sequences in the batch are of different lengths, so we unpack
+        # the padded tensor and run the forward pass separately on the
+        # sequence taken from each row in the batch.
+        x, lengths = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
+        x_bsz = x.size(0)
+        if h is None:
+            h = self.init_hidden(x_bsz)
+        y_results = []
+        h_results = []
+        for i, (sequence, seq_len) in enumerate(zip(x, lengths)):
+            # Take the initial state of the current batch element.
+            # Unsqueeze to keep a 3D tensor (batch dimension of size 1).
+            h_current = (h[0][:, i, :].unsqueeze(1), h[1][:, i, :].unsqueeze(1))
+            # Take only the valid timesteps, according to seq_len
+            sequence = sequence[:seq_len].unsqueeze(1)  # sequence.shape = (seq_len, batch_size=1, input_dim)
+            # Forward pass:
+            y, h_current = self.forward_fn(sequence, h_current)
+            # Squeeze the batch dimension back out of the output
+            y_results.append(y.squeeze(1))
+            h_results.append(h_current)
+        # Re-pack the individual outputs into a single PackedSequence
+        y = nn.utils.rnn.pack_sequence(y_results)
+        # Concatenate the hidden and cell states along the batch dimension
+        h = torch.cat([t[0] for t in h_results], dim=1), torch.cat([t[1] for t in h_results], dim=1)
+        return y, h
+
     def process_layer_wise(self, x, h):
         results = []
         for step in x:
@@ -304,6 +331,9 @@ def _layer_chain_unidirectional(self, step, h):
         """
         Process a single timestep through the entire unidirectional layer chain.
""" + step_bsz = step.size(0) + if h is None: + h = self.init_hidden(step_bsz) h_all, c_all = h h_result = [] out = step diff --git a/tests/test_lstm_impl.py b/tests/test_lstm_impl.py index 319015d69..616aa6a14 100644 --- a/tests/test_lstm_impl.py +++ b/tests/test_lstm_impl.py @@ -4,8 +4,11 @@ from distiller.modules import DistillerLSTM, DistillerLSTMCell import torch import torch.nn as nn +from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence, PackedSequence +from torch.testing import assert_allclose -ACCEPTABLE_ERROR = 5e-5 +ATOL = 5e-5 +RTOL = 1e-3 BATCH_SIZE = 32 SEQUENCE_SIZE = 35 @@ -45,11 +48,15 @@ def test_conversion(): def assert_output(out_true, out_pred): y_t, h_t = out_true y_p, h_p = out_pred - assert (y_t - y_p).abs().max().item() < ACCEPTABLE_ERROR + if isinstance(y_t, PackedSequence): + y_t, lenghts_t = pad_packed_sequence(y_t) + y_p, lenghts_p = pad_packed_sequence(y_p) + assert all(lenghts_t == lenghts_p) + assert_allclose(y_p, y_t, RTOL, ATOL) h_h_t, h_c_t = h_t h_h_p, h_c_p = h_p - assert (h_h_t - h_h_p).abs().max().item() < ACCEPTABLE_ERROR - assert (h_c_t - h_c_p).abs().max().item() < ACCEPTABLE_ERROR + assert_allclose(h_h_p, h_h_t, RTOL, ATOL) + assert_allclose(h_c_p, h_c_t, RTOL, ATOL) @pytest.fixture(name='bidirectional', params=[False, True], ids=['bidirectional_off', 'bidirectional_on']) @@ -91,3 +98,24 @@ def test_forward_lstm(input_size, hidden_size, num_layers, bidirectional): out_true = lstm(x, h) out_pred = lstm_man(x, h) assert_output(out_true, out_pred) + + +@pytest.mark.parametrize( + "input_size, hidden_size, num_layers, input_lengths", + [ + (1, 1, 2, [5, 4, 3]), + (3, 5, 7, [20, 15, 5]), + (500, 500, 5, [50, 35, 25]) + ] +) +def test_packed_sequence(input_size, hidden_size, num_layers, input_lengths, bidirectional): + lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional) + lstm_man = DistillerLSTM.from_pytorch_impl(lstm) + lstm.eval() + lstm_man.eval() + + h = lstm_man.init_hidden(BATCH_SIZE) + x = pack_sequence([torch.rand(length, input_size) for length in input_lengths]) + out_true = lstm(x) + out_pred = lstm_man(x) + assert_output(out_true, out_pred)