Performance improvement to IndRNN (~20% on forward+backward).
This change consists of two optimizations:
  1. Perform all pointwise operations in a single CUDA kernel
     (even across time steps).
  2. Reduce memory accesses in the backward pass by accumulating
     sums in registers and writing back to global memory at the
     end of the kernel.

The result is a ~20% speed improvement for a single training
iteration that consists of a forward and backward pass.

Issue: #7
sharvil committed Apr 1, 2020
1 parent feb4344 commit 1afeb86
Showing 2 changed files with 119 additions and 97 deletions.
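Before the diffs, here is a minimal sketch of the two ideas the commit message describes: fuse the per-step pointwise work into a single kernel launch, and accumulate reductions in registers so each thread touches global memory once at the end. The kernel name, parameter names, and layout below are illustrative assumptions, not the actual haste kernels.

// Sketch only: shows the launch fusion and register accumulation, not the
// real IndrnnBwdOps kernel. Assumes h and dh are laid out as
// [steps, batch, hidden] and du holds one gradient entry per hidden unit.
__global__ void FusedBwdSketch(const int steps,
                               const int batch_size,
                               const int hidden_size,
                               const float* h,
                               const float* dh,
                               float* du) {
  const int row = blockDim.x * blockIdx.x + threadIdx.x;  // hidden index
  const int col = blockDim.y * blockIdx.y + threadIdx.y;  // batch index
  if (row >= hidden_size || col >= batch_size)
    return;

  const int NH = batch_size * hidden_size;
  const int idx = col * hidden_size + row;

  // Optimization 1: one kernel launch walks every time step,
  // replacing one launch per step.
  float du_sum = 0.0f;
  for (int i = (steps - 1) * NH; i >= 0; i -= NH) {
    // Optimization 2: accumulate the partial sum in a register.
    du_sum += h[idx + i] * dh[idx + i];
  }

  // Threads that share `row` (different batch columns) reduce into the same
  // entry, so a single atomicAdd per thread replaces one atomic per step.
  atomicAdd(&du[row], du_sum);
}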
93 changes: 53 additions & 40 deletions lib/indrnn_backward_gpu.cu.cc
@@ -25,6 +25,7 @@ namespace {
 template<typename T, bool ApplyZoneout>
 __global__
 void IndrnnBwdOps(
+    const int steps,
     const int batch_size,
     const int hidden_size,
     const T* u,
@@ -42,22 +43,34 @@ void IndrnnBwdOps(
   if (row >= hidden_size || col >= batch_size)
     return;
 
+  const int NH = batch_size * hidden_size;
   const int idx = col * hidden_size + row;
 
-  T dh_total = dh_new[idx] + dh_inout[idx];
-  T dh = static_cast<T>(0.0);
-  if (ApplyZoneout) {
-    const T mask = zoneout_mask[idx];
-    dh = (static_cast<T>(1.0) - mask) * dh_total;
-    dh_total = mask * dh_total;
-  }
+  const T u_row = u[row];
+  T dh_inout_idx = dh_inout[idx];
+  T du_sum = static_cast<T>(0.0);
+  T db_sum = static_cast<T>(0.0);
+
+  for (int i = (steps - 1) * NH; i >= 0; i -= NH) {
+    T dh_total = dh_new[idx + i] + dh_inout_idx;
+    T dh = static_cast<T>(0.0);
+    if (ApplyZoneout) {
+      const T mask = zoneout_mask[idx + i];
+      dh = (static_cast<T>(1.0) - mask) * dh_total;
+      dh_total = mask * dh_total;
+    }
 
-  const T dk = d_tanh(h[idx]) * dh_total;
+    const T dk = d_tanh(h[idx + i]) * dh_total;
 
-  dk_out[idx] = dk;
-  dh_inout[idx] = dh + u[row] * dk;
-  atomicAdd(&du_out[row], h_prev[idx] * dk);
-  atomicAdd(&db_out[row], dk);
+    dk_out[idx + i] = dk;
+    dh_inout_idx = dh + u_row * dk;
+    du_sum += h_prev[idx + i] * dk;
+    db_sum += dk;
+  }
+
+  dh_inout[idx] = dh_inout_idx;
+  atomicAdd(&du_out[row], du_sum);
+  atomicAdd(&db_out[row], db_sum);
 }
 
 } // anonymous namespace
@@ -125,34 +138,34 @@ void BackwardPass<T>::Run(
       (hidden_size + blockDim.x - 1) / blockDim.x,
       (batch_size + blockDim.y - 1) / blockDim.y);
   const int NH = batch_size * hidden_size;
-  for (int i = steps - 1; i >= 0; --i) {
-    if (zoneout_mask) {
-      IndrnnBwdOps<T, true><<<gridDim, blockDim, 0, stream>>>(
-          batch_size,
-          hidden_size,
-          u,
-          h + i * NH,
-          h + (i + 1) * NH,
-          dh_new + (i + 1) * NH,
-          du,
-          db,
-          dh,
-          workspace + i * NH,
-          zoneout_mask + i * NH);
-    } else {
-      IndrnnBwdOps<T, false><<<gridDim, blockDim, 0, stream>>>(
-          batch_size,
-          hidden_size,
-          u,
-          h + i * NH,
-          h + (i + 1) * NH,
-          dh_new + (i + 1) * NH,
-          du,
-          db,
-          dh,
-          workspace + i * NH,
-          nullptr);
-    }
+  if (zoneout_mask) {
+    IndrnnBwdOps<T, true><<<gridDim, blockDim, 0, stream>>>(
+        steps,
+        batch_size,
+        hidden_size,
+        u,
+        h,
+        h + NH,
+        dh_new + NH,
+        du,
+        db,
+        dh,
+        workspace,
+        zoneout_mask);
+  } else {
+    IndrnnBwdOps<T, false><<<gridDim, blockDim, 0, stream>>>(
+        steps,
+        batch_size,
+        hidden_size,
+        u,
+        h,
+        h + NH,
+        dh_new + NH,
+        du,
+        db,
+        dh,
+        workspace,
+        nullptr);
+  }
   }
 
   cudaStream_t save_stream;
123 changes: 66 additions & 57 deletions lib/indrnn_forward_gpu.cu.cc
@@ -25,6 +25,7 @@ namespace {
 template<typename T, bool Training, bool ApplyZoneout>
 __global__
 void IndrnnFwdOps(
+    const int steps,
     const int batch_size,
     const int hidden_size,
     const T* Wx,
@@ -41,18 +42,24 @@ void IndrnnFwdOps(
     return;
 
   const int idx = col * hidden_size + row;
-  const T a = Wx[idx] + u[row] * h[idx] + b[row];
-  T cur_h_value = tanh(a);
+  const int NH = batch_size * hidden_size;
+  const T u_row = u[row];
+  const T b_row = b[row];
 
-  if (ApplyZoneout) {
-    if (Training) {
-      cur_h_value = (cur_h_value - h[idx]) * zoneout_mask[idx] + h[idx];
-    } else {
-      cur_h_value = (zoneout_prob * h[idx]) + ((1.0f - zoneout_prob) * cur_h_value);
+  for (int i = 0; i < steps * NH; i += NH) {
+    const T a = Wx[idx + i] + u_row * h[idx + i] + b_row;
+    T cur_h_value = tanh(a);
+
+    if (ApplyZoneout) {
+      if (Training) {
+        cur_h_value = (cur_h_value - h[idx + i]) * zoneout_mask[idx + i] + h[idx + i];
+      } else {
+        cur_h_value = (zoneout_prob * h[idx + i]) + ((1.0f - zoneout_prob) * cur_h_value);
+      }
     }
-  }
 
-  h_out[idx] = cur_h_value;
+    h_out[idx + i] = cur_h_value;
+  }
 }
 
 } // anonymous namespace
@@ -132,55 +139,57 @@ void ForwardPass<T>::Run(
       (hidden_size + blockDim.x - 1) / blockDim.x,
       (batch_size + blockDim.y - 1) / blockDim.y);
   const int NH = batch_size * hidden_size;
-  for (int i = 0; i < steps; ++i) {
-    if (training) {
-      if (zoneout_prob && zoneout_mask) {
-        IndrnnFwdOps<T, true, true><<<gridDim, blockDim, 0, stream>>>(
-            batch_size,
-            hidden_size,
-            workspace + i * NH,
-            u,
-            b,
-            h + i * NH,
-            h + (i + 1) * NH,
-            zoneout_prob,
-            zoneout_mask + i * NH);
-      } else {
-        IndrnnFwdOps<T, true, false><<<gridDim, blockDim, 0, stream>>>(
-            batch_size,
-            hidden_size,
-            workspace + i * NH,
-            u,
-            b,
-            h + i * NH,
-            h + (i + 1) * NH,
-            0.0f,
-            nullptr);
-      }
+  if (training) {
+    if (zoneout_prob && zoneout_mask) {
+      IndrnnFwdOps<T, true, true><<<gridDim, blockDim, 0, stream>>>(
+          steps,
+          batch_size,
+          hidden_size,
+          workspace,
+          u,
+          b,
+          h,
+          h + NH,
+          zoneout_prob,
+          zoneout_mask);
     } else {
-      if (zoneout_prob && zoneout_mask) {
-        IndrnnFwdOps<T, false, true><<<gridDim, blockDim, 0, stream>>>(
-            batch_size,
-            hidden_size,
-            workspace + i * NH,
-            u,
-            b,
-            h + i * NH,
-            h + (i + 1) * NH,
-            zoneout_prob,
-            zoneout_mask + i * NH);
-      } else {
-        IndrnnFwdOps<T, false, false><<<gridDim, blockDim, 0, stream>>>(
-            batch_size,
-            hidden_size,
-            workspace + i * NH,
-            u,
-            b,
-            h + i * NH,
-            h + (i + 1) * NH,
-            0.0f,
-            nullptr);
-      }
+      IndrnnFwdOps<T, true, false><<<gridDim, blockDim, 0, stream>>>(
+          steps,
+          batch_size,
+          hidden_size,
+          workspace,
+          u,
+          b,
+          h,
+          h + NH,
+          0.0f,
+          nullptr);
     }
-  }
+  } else {
+    if (zoneout_prob && zoneout_mask) {
+      IndrnnFwdOps<T, false, true><<<gridDim, blockDim, 0, stream>>>(
+          steps,
+          batch_size,
+          hidden_size,
+          workspace,
+          u,
+          b,
+          h,
+          h + NH,
+          zoneout_prob,
+          zoneout_mask);
+    } else {
+      IndrnnFwdOps<T, false, false><<<gridDim, blockDim, 0, stream>>>(
+          steps,
+          batch_size,
+          hidden_size,
+          workspace,
+          u,
+          b,
+          h,
+          h + NH,
+          0.0f,
+          nullptr);
+    }
+  }
 
