Added PackedSequence functionality (IntelLabs#236)
* Update test_lstm_impl.py

* Added PackedSequence functionality

* Refactored forward implementation
levzlotnik authored Apr 30, 2019
1 parent 8c5de42 commit 92fd001
Showing 2 changed files with 67 additions and 9 deletions.
40 changes: 35 additions & 5 deletions distiller/modules/rnn.py
@@ -22,6 +22,7 @@

__all__ = ['DistillerLSTMCell', 'DistillerLSTM']


class DistillerLSTMCell(nn.Module):
    """
    A single LSTM block.
@@ -231,9 +232,9 @@ def __init__(self, input_size, hidden_size, num_layers, bias=True, batch_first=F
    def forward(self, x, h=None):
        is_packed_seq = isinstance(x, nn.utils.rnn.PackedSequence)
        if is_packed_seq:
            x, lengths = nn.utils.rnn.pad_packed_sequence(x, self.batch_first)
            return self.packed_sequence_forward(x, h)

        elif self.batch_first:
        if self.batch_first:
            # Transpose to sequence_first format
            x = x.transpose(0, 1)
        x_bsz = x.size(1)
@@ -242,14 +243,40 @@ def forward(self, x, h=None):
            h = self.init_hidden(x_bsz)

        y, h = self.forward_fn(x, h)
        if is_packed_seq:
            y = nn.utils.rnn.pack_padded_sequence(y, lengths, self.batch_first)

        elif self.batch_first:
        if self.batch_first:
            # Transpose back to batch_first format
            y = y.transpose(0, 1)
        return y, h

    def packed_sequence_forward(self, x, h=None):
        # Packed-sequence handling: the sequences in the batch have different
        # lengths, so we unpack the padded tensor and run each row of the
        # batch (one sequence) through the network separately.
        x, lengths = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
        x_bsz = x.size(0)
        if h is None:
            h = self.init_hidden(x_bsz)
        y_results = []
        h_results = []
        for i, (sequence, seq_len) in enumerate(zip(x, lengths)):
            # Take the previous state of the current batch element;
            # unsqueeze to keep a 3D (layers, batch=1, hidden) tensor.
            h_current = (h[0][:, i, :].unsqueeze(1), h[1][:, i, :].unsqueeze(1))
            # Keep only the valid timesteps according to seq_len.
            sequence = sequence[:seq_len].unsqueeze(1)  # sequence.shape = (seq_len, batch_size=1, input_dim)
            # Forward pass:
            y, h_current = self.forward_fn(sequence, h_current)
            # Squeeze the batch dimension back out to get a single sequence.
            y_results.append(y.squeeze(1))
            h_results.append(h_current)
        # The result is a packed sequence.
        y = nn.utils.rnn.pack_sequence(y_results)
        # Concatenate the per-sequence hidden states back along the batch dimension.
        h = torch.cat([t[0] for t in h_results], dim=1), torch.cat([t[1] for t in h_results], dim=1)
        return y, h

    def process_layer_wise(self, x, h):
        results = []
        for step in x:
@@ -304,6 +331,9 @@ def _layer_chain_unidirectional(self, step, h):
"""
Process a single timestep through the entire unidirectional layer chain.
"""
step_bsz = step.size(0)
if h is None:
h = self.init_hidden(step_bsz)
h_all, c_all = h
h_result = []
out = step
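
For context, a minimal standalone sketch (not part of the commit; the tensor sizes are made up, and it assumes a PyTorch version in which pack_sequence expects inputs sorted by decreasing length) of the pack/pad round trip and the per-sequence slicing that packed_sequence_forward above relies on:

import torch
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence

# Three sequences of lengths 5, 4, 3 with 10 features each, ordered by
# decreasing length.
packed = pack_sequence([torch.rand(l, 10) for l in (5, 4, 3)])

# pad_packed_sequence recovers a padded (batch, max_len, features) tensor
# plus the true lengths, the same call packed_sequence_forward starts with.
padded, lengths = pad_packed_sequence(packed, batch_first=True)
print(padded.shape, lengths)   # torch.Size([3, 5, 10]) tensor([5, 4, 3])

# Each row is trimmed to its true length and treated as a batch of one,
# and the matching slice of the hidden state is taken along the batch
# dimension (here: 2 layers, batch of 3, hidden size 4).
h = (torch.zeros(2, 3, 4), torch.zeros(2, 3, 4))
for i, (row, seq_len) in enumerate(zip(padded, lengths)):
    single = row[:seq_len].unsqueeze(1)                              # (seq_len, 1, features)
    h_i = (h[0][:, i, :].unsqueeze(1), h[1][:, i, :].unsqueeze(1))   # (2, 1, 4) each
    print(single.shape, h_i[0].shape)
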
36 changes: 32 additions & 4 deletions tests/test_lstm_impl.py
@@ -4,8 +4,11 @@
from distiller.modules import DistillerLSTM, DistillerLSTMCell
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence, PackedSequence
from torch.testing import assert_allclose

ACCEPTABLE_ERROR = 5e-5
ATOL = 5e-5
RTOL = 1e-3
BATCH_SIZE = 32
SEQUENCE_SIZE = 35

@@ -45,11 +48,15 @@ def test_conversion():
def assert_output(out_true, out_pred):
    y_t, h_t = out_true
    y_p, h_p = out_pred
    assert (y_t - y_p).abs().max().item() < ACCEPTABLE_ERROR
    if isinstance(y_t, PackedSequence):
        y_t, lengths_t = pad_packed_sequence(y_t)
        y_p, lengths_p = pad_packed_sequence(y_p)
        assert all(lengths_t == lengths_p)
    assert_allclose(y_p, y_t, RTOL, ATOL)
    h_h_t, h_c_t = h_t
    h_h_p, h_c_p = h_p
    assert (h_h_t - h_h_p).abs().max().item() < ACCEPTABLE_ERROR
    assert (h_c_t - h_c_p).abs().max().item() < ACCEPTABLE_ERROR
    assert_allclose(h_h_p, h_h_t, RTOL, ATOL)
    assert_allclose(h_c_p, h_c_t, RTOL, ATOL)


@pytest.fixture(name='bidirectional', params=[False, True], ids=['bidirectional_off', 'bidirectional_on'])
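
Side note (an illustration added here, not part of the diff): assert_allclose checks each element against a combined tolerance, |actual - expected| <= ATOL + RTOL * |expected|, so the relative term scales with the magnitude of the values being compared, unlike the single absolute ACCEPTABLE_ERROR bound it replaces. A minimal sketch:

import torch
from torch.testing import assert_allclose

actual = torch.tensor([1.00004, 100.04])
expected = torch.tensor([1.0, 100.0])

# Passes: per element, |actual - expected| <= atol + rtol * |expected|
#   element 0: 4e-5 <= 5e-5 + 1e-3 * 1.0
#   element 1: 4e-2 <= 5e-5 + 1e-3 * 100.0
assert_allclose(actual, expected, rtol=1e-3, atol=5e-5)
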
@@ -91,3 +98,24 @@ def test_forward_lstm(input_size, hidden_size, num_layers, bidirectional):
    out_true = lstm(x, h)
    out_pred = lstm_man(x, h)
    assert_output(out_true, out_pred)


@pytest.mark.parametrize(
    "input_size, hidden_size, num_layers, input_lengths",
    [
        (1, 1, 2, [5, 4, 3]),
        (3, 5, 7, [20, 15, 5]),
        (500, 500, 5, [50, 35, 25])
    ]
)
def test_packed_sequence(input_size, hidden_size, num_layers, input_lengths, bidirectional):
    lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)
    lstm_man = DistillerLSTM.from_pytorch_impl(lstm)
    lstm.eval()
    lstm_man.eval()

    h = lstm_man.init_hidden(BATCH_SIZE)
    x = pack_sequence([torch.rand(length, input_size) for length in input_lengths])
    out_true = lstm(x)
    out_pred = lstm_man(x)
    assert_output(out_true, out_pred)
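
A hedged usage sketch, mirroring the test above (the layer sizes and sequence lengths here are made up for illustration): convert an nn.LSTM, feed the converted module a packed batch, and unpad the PackedSequence it returns.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
from distiller.modules import DistillerLSTM

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
lstm_man = DistillerLSTM.from_pytorch_impl(lstm)
lstm_man.eval()

# Sequences sorted by decreasing length, as pack_sequence expects here.
x = pack_sequence([torch.rand(l, 10) for l in (7, 5, 3)])
y_packed, (h_n, c_n) = lstm_man(x)

# The output is itself a PackedSequence; unpad it to (max_len, batch, hidden).
y, lengths = pad_packed_sequence(y_packed)
print(y.shape, h_n.shape)   # torch.Size([7, 3, 20]) torch.Size([2, 3, 20])
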
