IndRNNs #7

Open

OverLordGoldDragon opened this issue Mar 16, 2020 · 10 comments

@OverLordGoldDragon

I'd like to suggest support for IndRNNs; in my experiments on EEG seizure classification w/ very long sequences, they've dominated LSTMs & GRUs. While already much faster, IndRNNs would further benefit from a CuDNN-like speedup in large stacks, and from Layer Normalization when working w/ 1000+ timesteps.

Minimal tf.keras code below; default weight initialization should be handled differently - can clarify post-approval.


IndRNN Cell
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin



@keras_export(v1=['keras.layers.IndRNNCell'])
class IndRNNCell(DropoutRNNCellMixin, Layer):
  def __init__(self,
               units,
               activation='tanh',
               use_bias=True,
               recurrent_clip_min=-1,
               recurrent_clip_max=-1,
               kernel_initializer='glorot_normal',
               recurrent_initializer=None,
               bias_initializer='zeros',
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               implementation=1,
               **kwargs):
    super(IndRNNCell, self).__init__(**kwargs)
    
    if recurrent_clip_min is None or recurrent_clip_max is None:
        recurrent_clip_min = None
        recurrent_clip_max = None

    self.units = units
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.recurrent_clip_min = recurrent_clip_min
    self.recurrent_clip_max = recurrent_clip_max

    self.kernel_initializer = initializers.get(kernel_initializer)
    if recurrent_initializer is None:
        # default: uniform in [-1, 1]; see the initialization discussion below
        # on smaller bounds for long sequences
        self.recurrent_initializer = initializers.RandomUniform(-1.0, 1.0)
    else:
        self.recurrent_initializer = initializers.get(recurrent_initializer)
    self.bias_initializer = initializers.get(bias_initializer)

    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)

    self.kernel_constraint = constraints.get(kernel_constraint)
    self.recurrent_constraint = constraints.get(recurrent_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

    self.dropout = min(1., max(0., dropout))
    self.recurrent_dropout = min(1., max(0., recurrent_dropout))

    self.state_size = data_structures.NoDependency([self.units])
    self.output_size = self.units

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    input_dim = input_shape[-1]
    self.timesteps = input_shape[1]
    # derives effective clipping bounds for the recurrent weights from
    # self.timesteps; method omitted from this minimal excerpt
    self._process_recurrent_clip()

    self.kernel = self.add_weight(
        shape=(input_dim, self.units),
        name='kernel',
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint)
    self.recurrent_kernel = self.add_weight(
        shape=(self.units,),
        name='recurrent_kernel',
        initializer=self.recurrent_initializer,
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)

    if self.use_bias:
      self.bias = self.add_weight(
          shape=(self.units,),
          name='bias',
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None    
    self.built = True


  def call(self, inputs, states, training=None):
    h_tm1 = states[0]  # previous memory state

    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=1)
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
        h_tm1, training, count=1)

    if 0. < self.dropout < 1.:
      inputs = inputs * dp_mask[0]
    if 0. < self.recurrent_dropout < 1.:
      h_tm1 = h_tm1 * rec_dp_mask[0]
    
    h = K.dot(inputs, self.kernel)
    h += math_ops.multiply(h_tm1, self.recurrent_kernel)
    if self.use_bias:
      h = K.bias_add(h, self.bias)

    h = self.activation(h)
    return h, [h]
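
For context, a quick usage sketch (assuming the cell above with its omitted clipping helper filled in; shapes and sizes are arbitrary) - the cell plugs into the standard tf.keras.layers.RNN wrapper:

import tensorflow as tf

# Example: two stacked IndRNN layers on 1000-timestep, 64-feature sequences.
inputs = tf.keras.Input(shape=(1000, 64))
x = tf.keras.layers.RNN(IndRNNCell(128), return_sequences=True)(inputs)
x = tf.keras.layers.RNN(IndRNNCell(128))(x)
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
model = tf.keras.Model(inputs, outputs)
model.compile('adam', 'binary_crossentropy')
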
@sharvil
Contributor

sharvil commented Mar 17, 2020

Sounds like a reasonable addition. The IndRNN paper doesn't mention layer norm – do you have an algorithm that's known to produce good results? I'm wondering if we can get away with applying layer norm only on the input transform.
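
A rough sketch of that idea, in terms of the cell code above (self.layernorm here is a hypothetical LayerNormalization sub-layer over the Wx term only):

# Sketch only: normalize the input projection, leave the recurrent term as-is.
Wx = self.layernorm(K.dot(inputs, self.kernel))
h = Wx + math_ops.multiply(h_tm1, self.recurrent_kernel)
if self.use_bias:
  h = K.bias_add(h, self.bias)
h = self.activation(h)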

How important are cell clipping, input dropout, recurrent dropout, and non-tanh activation functions in practice? And what weight initialization scheme do you propose?

@OverLordGoldDragon
Author

OverLordGoldDragon commented Mar 18, 2020

@sharvil Wonderful.


WEIGHT INITIALIZATION: author repo

kernel: authors recommend Normal w/ small std. In my application, 'glorot_normal' and 'glorot_uniform' have yielded comparable results (didn't try others); note that 'glorot_normal' is truncated (I suggest all default Normals be truncated to avoid unlikely but possible extreme weight values). I'd default to 'glorot_normal', but no strong inclination.

recurrent_kernel: authors recommend a sophisticated initialization scheme w/ timesteps-based clipping. Several points:

  1. Timesteps-based clipping will be a bit involved to implement in a Keras API-friendly manner.
  2. I'm not very convinced by the need for such elaborate clipping; per this graph I made based on the authors' excerpt, the difference between clipped and simply [-1, 1] weights is quite small for long sequences, and the authors themselves note clipping is redundant for short (<20 timesteps) sequences. More importantly, being around 1 at all may be harmful (see below).
  3. Most implementations default to uniform [-1, 1]; I recommend against this. In my application, [-.2, .2] has worked best, and [-1, 1] was a wreck; my explanation is that large weights yield large pre-activations, driving tanh into saturation (and relu into explosion) and harming backprop for long sequences. With my scheme, I achieved near picture-perfect gradients for 160+ timesteps.
  4. My recommendation is [-.5, .5] (uniform), with a docstring note on the difference w/ the authors. IndRNNs are likelier to be used for long sequence tasks, where [-1, 1] can be a bad default. Caveats:
  • My experiments are limited to signal classification; results may differ in other domains
  • If defaulting to [-1, 1], it's worth mentioning trying smaller bounds for longer sequences in a docstring
  5. Whatever the default, I suggest Haste provide a more convenient way to initialize via uniform or (truncated) normal. TF/Keras requires an import; instead, we can take a dict like {'uniform': 1} to mean RandomUniform(-1, 1) or {'normal': .01} to mean TruncatedNormal(stddev=.01) - see the sketch below.
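
For instance, a rough sketch of such a mapping (a hypothetical helper, not existing Haste or Keras API):

from tensorflow.keras import initializers

def get_initializer(spec):
    # Hypothetical convenience: {'uniform': b} -> RandomUniform(-b, b),
    # {'normal': s} -> TruncatedNormal(stddev=s); anything else falls back
    # to the usual Keras lookup.
    if isinstance(spec, dict):
        if 'uniform' in spec:
            b = spec['uniform']
            return initializers.RandomUniform(-b, b)
        if 'normal' in spec:
            return initializers.TruncatedNormal(stddev=spec['normal'])
    return initializers.get(spec)

# get_initializer({'uniform': .5})  -> RandomUniform(-0.5, 0.5)
# get_initializer({'normal': .01}) -> TruncatedNormal(stddev=0.01)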

bias: use_bias=True, initialize to zeros. Same as authors'.


ACTIVATION:

'relu' bombed badly in my application; 'selu' was stabler - but this is rather inherent to long sequences. The authors' success with relu may be domain-specific; for a safer general default, and what proved superior in my case, I recommend 'tanh'.


LAYER NORMALIZATION:

The benefits of LayerNorm or BatchNorm, especially implemented recurrently for RNNs, are basically universal (some interesting reading here) - and will be even more pronounced for long sequence tasks with typical vanishing gradients. For IndRNNs, it remains important to normalize both the input-to-hidden and hidden-to-hidden transforms, and to do so separately; the authors of recurrent batch normalization go so far as to normalize gates separately, with sound arguments. The idea is that information flow dynamics are unique to each transform - and IndRNN's recurrent_kernel is, additionally, a vector rather than a matrix.
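
For the cell above, a rough sketch of what this could look like (layernorm_x and layernorm_rec being hypothetical LayerNormalization sub-layers, one per transform):

# Sketch only: normalize input and recurrent transforms separately, each with
# its own LayerNormalization (own gain/bias), then combine and activate.
Wx = self.layernorm_x(K.dot(inputs, self.kernel))
uh = self.layernorm_rec(math_ops.multiply(h_tm1, self.recurrent_kernel))
h = Wx + uh
if self.use_bias:
  h = K.bias_add(h, self.bias)
h = self.activation(h)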

Btw, Recurrent BN may prove superior to Recurrent LN, though it's harder to implement - but that's for another time.


DROPOUTS: should behave the same as TF/Keras's SimpleRNN. Though recurrent_dropout as-is is problematic (for all RNNs) - I'll clarify another time; can mimic TF for starters.

sharvil added a commit that referenced this issue Mar 19, 2020
This implementation doesn't support Zoneout regularization or Layer
Normalization. The weight initialization scheme is primitive, and
the RNN layer doesn't support activation functions other than `tanh`.
Still, this change provides a good starting point and the required
scaffolding for a more sophisticated implementation.

Issue: #7
sharvil added a commit that referenced this issue Mar 25, 2020
sharvil added a commit that referenced this issue Mar 26, 2020
sharvil added a commit that referenced this issue Mar 26, 2020
sharvil added a commit that referenced this issue Mar 26, 2020
See discussion in #7. In short, default to random uniform [-0.5, 0.5].

Issue: #7
@sharvil
Contributor

sharvil commented Mar 26, 2020

Thanks for the detailed writeup. I'm following your recommendations for the most part but I don't think dropout is a good addition. Specifically, recurrent dropout is known to be A Bad Idea™ for RNNs as you pointed out and dropout between layers can be implemented trivially by the caller – it doesn't need to be implemented by the Haste layer.
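
For example, caller-side dropout between stacked layers is just (a sketch, using the tf.keras cell from this issue as a stand-in):

x = tf.keras.layers.RNN(IndRNNCell(128), return_sequences=True)(inputs)
x = tf.keras.layers.Dropout(0.2)(x)  # inter-layer dropout, applied by the caller
x = tf.keras.layers.RNN(IndRNNCell(128))(x)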

@OverLordGoldDragon
Author

@sharvil I see your commit history - good progress so far, though I hope LayerNorm support is planned as it can make IndRNNs vastly more powerful for very long sequences. Regarding recurrent dropout, I disagree - it can be a very effective regularizer if used properly, though I wouldn't follow TensorFlow's implementation for it - I'll dedicate an Issue to this sometime.

@sharvil
Contributor

sharvil commented Apr 1, 2020

Yes, Layer Norm support is planned, though no specific ETA yet.

sharvil added a commit that referenced this issue Apr 1, 2020
This change consists of two optimizations:
  1. Perform all pointwise operations in a single CUDA kernel
     (even across time steps).
  2. Reduce memory accesses in the backward pass by accumulating
     sums in registers and writing back to global memory at the
     end of the kernel.

The result is a ~20% speed improvement for a single training
iteration that consists of a forward and backward pass.

Issue: #7
sharvil added a commit that referenced this issue May 4, 2020
This patch adds layer normalization to the input activations but
not the recurrent activations. In a future patch, we'll support
normalizing recurrent activations as well.

Issue: #7
@amurashov

I have tested IndRNN - do I understand correctly that it is not production-ready yet?

Training fails because the gradient of indrnn_cell_36/recurrent_scale:0 has the wrong shape:

grad shape: (512, 512) (all but the first row are zeros)
weight shape: (512,)

@sharvil
Contributor

sharvil commented Jun 1, 2020

@amurashov thanks for the report. Looks like a tensor-shape bug in my TensorFlow integration, which I've now fixed. IndRNN is at a functional stage; LayerNormIndRNN is usable but only performs layer norm on the inputs, not on the recurrent connection.

@amurashov

Yes, appears fixed. Thanks!

@bratao

bratao commented Jul 1, 2020

@sharvil Thank you for your awesome work. This is pure gold!!
Can we consider IndRNN as production ready?

@sharvil
Contributor

sharvil commented Jul 1, 2020

@bratao, IndRNN is complete and ready for prime-time use. I've kept this issue open to track the full LayerNormIndRNN implementation. That still needs some work which I haven't gotten around to yet.
