Add Zoneout support to IndRNN (PyTorch).
Issue: #7
sharvil committed Mar 25, 2020
1 parent e7cd04a commit f6ca8c2
Showing 5 changed files with 158 additions and 42 deletions.
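For context, a minimal usage sketch of the zoneout option this commit adds. The import path, shapes, and hyperparameters below are assumptions for illustration, not part of the diff:

import torch
from frameworks.pytorch.indrnn import IndRNN  # import path assumed; adjust to your checkout/package

# IndRNN with 10% zoneout on the recurrent state. The fused kernel path requires
# CUDA tensors; on CPU the module falls back to the IndRNNScript reference below.
rnn = IndRNN(input_size=128, hidden_size=256, zoneout=0.1).cuda()

x = torch.randn(50, 32, 128, device='cuda')  # (time_steps, batch, input_size)
output, state = rnn(x)                       # a fresh Bernoulli mask is drawn on every call when zoneout > 0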
16 changes: 13 additions & 3 deletions frameworks/pytorch/indrnn.cc
@@ -29,21 +29,25 @@ using torch::Tensor;

Tensor indrnn_forward(
bool training,
float zoneout_prob,
Tensor x,
Tensor h0,
Tensor kernel,
Tensor recurrent_scale,
Tensor bias) {
Tensor bias,
Tensor zoneout_mask) {
const auto time_steps = x.size(0);
const auto batch_size = x.size(1);
const auto input_size = x.size(2);
const auto hidden_size = recurrent_scale.size(0);
const bool has_zoneout = zoneout_prob && zoneout_mask.size(0);

CHECK_INPUT(x);
CHECK_INPUT(h0);
CHECK_INPUT(kernel);
CHECK_INPUT(recurrent_scale);
CHECK_INPUT(bias);
CHECK_INPUT(zoneout_mask);

Tensor output = torch::empty({ time_steps + 1, batch_size, hidden_size }, at::kCUDA);
Tensor workspace = torch::empty({ time_steps, batch_size, hidden_size }, at::kCUDA);
@@ -65,7 +69,9 @@ Tensor indrnn_forward(
bias.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>(),
workspace.data_ptr<scalar_t>());
workspace.data_ptr<scalar_t>(),
has_zoneout ? zoneout_prob : 0.0f,
has_zoneout ? zoneout_mask.data_ptr<scalar_t>() : nullptr);
}));

return output;
@@ -76,17 +82,20 @@ std::vector<Tensor> indrnn_backward(
Tensor kernel_t,
Tensor recurrent_scale,
Tensor bias,
Tensor zoneout_mask,
Tensor h,
Tensor dh_new) {
const auto input_size = x_t.size(0);
const auto time_steps = x_t.size(1);
const auto batch_size = x_t.size(2);
const auto hidden_size = recurrent_scale.size(0);
const bool has_zoneout = !!zoneout_mask.size(0);

CHECK_INPUT(x_t);
CHECK_INPUT(kernel_t);
CHECK_INPUT(recurrent_scale);
CHECK_INPUT(bias);
CHECK_INPUT(zoneout_mask);
CHECK_INPUT(h);
CHECK_INPUT(dh_new);

@@ -117,7 +126,8 @@ std::vector<Tensor> indrnn_backward(
du.data_ptr<scalar_t>(),
db.data_ptr<scalar_t>(),
dh.data_ptr<scalar_t>(),
workspace.data_ptr<scalar_t>());
workspace.data_ptr<scalar_t>(),
has_zoneout ? zoneout_mask.data_ptr<scalar_t>() : nullptr);
}));

return { dx, dh, dW, du, db };
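In the binding above, zoneout is effectively disabled when either zoneout_prob is zero or zoneout_mask is empty (the has_zoneout check). A hedged sketch of how a Python caller can exercise that gate directly; LIB stands for the compiled extension module, whose import name is not shown in this diff, and all tensor names are illustrative:

# Illustrative only: an empty mask makes has_zoneout false, so the kernel skips zoneout.
zoneout_prob = 0.0
if zoneout_prob:
    zoneout_mask = x.new_empty(time_steps, batch_size, hidden_size).bernoulli_(1.0 - zoneout_prob)
else:
    zoneout_mask = x.new_empty(0, 0, 0)  # zoneout_mask.size(0) == 0

h = LIB.indrnn_forward(True, zoneout_prob, x, h0, kernel, recurrent_scale, bias, zoneout_mask)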
40 changes: 31 additions & 9 deletions frameworks/pytorch/indrnn.py
@@ -27,25 +27,32 @@
#@torch.jit.script
def IndRNNScript(
training: bool,
zoneout_prob: float,
input,
h0,
kernel,
recurrent_scale,
bias):
bias,
zoneout_mask):
time_steps = input.shape[0]

h = [h0]
Wx = input @ kernel + bias
for t in range(time_steps):
h.append(torch.tanh(Wx[t] + h[-1] * recurrent_scale))
if zoneout_prob:
if training:
h[-1] = (h[-1] - h[-2]) * zoneout_mask[t] + h[-2]
else:
h[-1] = zoneout_prob * h[-2] + (1 - zoneout_prob) * h[-1]
h = torch.stack(h)
return h


class IndRNNFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, training, *inputs):
h = LIB.indrnn_forward(training, *inputs)
def forward(ctx, training, zoneout_prob, *inputs):
h = LIB.indrnn_forward(training, zoneout_prob, *inputs)
ctx.save_for_backward(inputs[0], *inputs[2:], h)
ctx.training = training
return h
@@ -58,20 +65,25 @@ def backward(ctx, grad_h):
saved[0] = saved[0].permute(2, 0, 1).contiguous()
saved[1] = saved[1].permute(1, 0).contiguous()
grads = LIB.indrnn_backward(*saved, grad_h.contiguous())
return (None, *grads)
return (None, None, *grads, None)


class IndRNN(nn.Module):
def __init__(
self,
input_size,
hidden_size,
batch_first=False):
batch_first=False,
zoneout=0.0):
super(IndRNN, self).__init__()

if zoneout < 0 or zoneout > 1:
raise ValueError('IndRNN: zoneout must be in [0.0, 1.0]')

self.input_size = input_size
self.hidden_size = hidden_size
self.batch_first = batch_first
self.zoneout = zoneout

self.kernel = nn.Parameter(torch.empty(input_size, hidden_size))
self.recurrent_scale = nn.Parameter(torch.empty(hidden_size))
@@ -87,14 +99,20 @@ def forward(self, input, state=None, lengths=None):
if self.batch_first:
input = input.permute(1, 0, 2)

if self.zoneout:
zoneout_mask = input.new_empty(input.shape[0], input.shape[1], self.hidden_size)
zoneout_mask.bernoulli_(1.0 - self.zoneout)
else:
zoneout_mask = input.new_empty(0, 0, 0)

if state is None:
h0 = input.new_zeros(input.shape[1], self.hidden_size)
elif state.shape[0] != 1:
raise ValueError('initial state for IndRNN must have leading dimension of 1')
else:
h0 = state[0]

h = self._impl(input, h0)
h = self._impl(input, h0, zoneout_mask)

if lengths is not None:
cols = range(h.size(1))
@@ -108,7 +126,7 @@ def forward(self, input, state=None, lengths=None):

return output, state

def _impl(self, input, state):
def _impl(self, input, state, zoneout_mask):
tensors = [input, self.kernel, self.recurrent_scale, self.bias]
is_cuda = [tensor.is_cuda for tensor in tensors]
if any(is_cuda) and not all(is_cuda):
@@ -117,16 +135,20 @@ def _impl(self, input, state):
if all(is_cuda):
return IndRNNFunction.apply(
self.training,
self.zoneout,
input.contiguous(),
state.contiguous(),
self.kernel.contiguous(),
self.recurrent_scale.contiguous(),
self.bias.contiguous())
self.bias.contiguous(),
zoneout_mask.contiguous())
else:
return IndRNNScript(
self.training,
self.zoneout,
input.contiguous(),
state.contiguous(),
self.kernel.contiguous(),
self.recurrent_scale.contiguous(),
self.bias.contiguous())
self.bias.contiguous(),
zoneout_mask.contiguous())
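The reference implementation in IndRNNScript applies the standard zoneout rule per time step: during training each unit either keeps its previous value (mask 0) or takes the new candidate (mask 1); at inference the expected value of that mixture is used. A standalone sketch of that per-step rule (names are illustrative):

import torch

def zoneout_step(h_candidate, h_prev, zoneout_prob, training):
    # Mirrors the branch inside IndRNNScript for a single time step.
    if not zoneout_prob:
        return h_candidate
    if training:
        mask = torch.empty_like(h_candidate).bernoulli_(1.0 - zoneout_prob)
        return (h_candidate - h_prev) * mask + h_prev   # mask==1 keeps the update, mask==0 keeps h_prev
    return zoneout_prob * h_prev + (1.0 - zoneout_prob) * h_candidate  # expected value at inference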
7 changes: 5 additions & 2 deletions lib/haste/indrnn.h
@@ -41,7 +41,9 @@ class ForwardPass {
const T* b,
const T* x,
T* h,
T* workspace);
T* workspace,
const float zoneout_prob,
const T* zoneout_mask);

private:
struct private_data;
@@ -72,7 +74,8 @@ class BackwardPass {
T* du,
T* db,
T* dh,
T* workspace);
T* workspace,
const T* zoneout_mask);

private:
struct private_data;
57 changes: 41 additions & 16 deletions lib/indrnn_backward_gpu.cu.cc
@@ -22,7 +22,7 @@

namespace {

template<typename T>
template<typename T, bool ApplyZoneout>
__global__
void IndrnnBwdOps(
const int batch_size,
@@ -34,7 +34,8 @@ void IndrnnBwdOps(
T* du_out,
T* db_out,
T* dh_inout,
T* dk_out) {
T* dk_out,
const T* zoneout_mask) {
const int row = blockDim.x * blockIdx.x + threadIdx.x;
const int col = blockDim.y * blockIdx.y + threadIdx.y;

@@ -43,11 +44,18 @@

const int idx = col * hidden_size + row;

const T dh_total = dh_new[idx] + dh_inout[idx];
T dh_total = dh_new[idx] + dh_inout[idx];
T dh = static_cast<T>(0.0);
if (ApplyZoneout) {
const T mask = zoneout_mask[idx];
dh = (static_cast<T>(1.0) - mask) * dh_total;
dh_total = mask * dh_total;
}

const T dk = d_tanh(h[idx]) * dh_total;

dk_out[idx] = dk;
dh_inout[idx] = u[row] * dk;
dh_inout[idx] = dh + u[row] * dk;
atomicAdd(&du_out[row], h_prev[idx] * dk);
atomicAdd(&db_out[row], dk);
}
@@ -101,7 +109,8 @@ void BackwardPass<T>::Run(
T* du,
T* db,
T* dh,
T* workspace) {
T* workspace,
const T* zoneout_mask) {
const T alpha = static_cast<T>(1.0);
const T beta = static_cast<T>(0.0);

@@ -117,17 +126,33 @@
(batch_size + blockDim.y - 1) / blockDim.y);
const int NH = batch_size * hidden_size;
for (int i = steps - 1; i >= 0; --i) {
IndrnnBwdOps<T><<<gridDim, blockDim, 0, stream>>>(
batch_size,
hidden_size,
u,
h + i * NH,
h + (i + 1) * NH,
dh_new + (i + 1) * NH,
du,
db,
dh,
workspace + i * NH);
if (zoneout_mask) {
IndrnnBwdOps<T, true><<<gridDim, blockDim, 0, stream>>>(
batch_size,
hidden_size,
u,
h + i * NH,
h + (i + 1) * NH,
dh_new + (i + 1) * NH,
du,
db,
dh,
workspace + i * NH,
zoneout_mask + i * NH);
} else {
IndrnnBwdOps<T, false><<<gridDim, blockDim, 0, stream>>>(
batch_size,
hidden_size,
u,
h + i * NH,
h + (i + 1) * NH,
dh_new + (i + 1) * NH,
du,
db,
dh,
workspace + i * NH,
nullptr);
}
}

cudaStream_t save_stream;
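A short sketch of the gradient split performed by the ApplyZoneout branch of IndrnnBwdOps, written elementwise in Python for clarity (assuming h stores the post-activation value so that d_tanh(y) = 1 - y^2; names are illustrative):

def indrnn_zoneout_backward(dh_new, dh_carry, h, h_prev, u, mask):
    # dh_new: gradient arriving from the loss at this step; dh_carry: gradient carried from step t+1.
    dh_total = dh_new + dh_carry
    dh_skip = (1.0 - mask) * dh_total   # share that bypassed the update via zoneout
    dh_total = mask * dh_total          # share that went through the tanh update
    dk = (1.0 - h * h) * dh_total       # d_tanh applied to the stored activation
    dh_prev = dh_skip + u * dk          # carried to time step t-1
    du = h_prev * dk                    # contribution to the recurrent_scale gradient
    db = dk                             # contribution to the bias gradient
    return dk, dh_prev, du, db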