IndRNNs #7
Sounds like a reasonable addition. The IndRNN paper doesn't mention layer norm – do you have an algorithm that's known to produce good results? I'm wondering if we can get away with applying layer norm only on the input (see the sketch below). How important are cell clipping, input dropout, recurrent dropout, and non-…?
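A minimal sketch of that idea, assuming a standalone step function (`W`, `u`, `b` are the input kernel, per-unit recurrent vector, and bias; `ln` is any LayerNormalization over the feature axis – placeholder names, not Haste's actual API):

```python
import tensorflow as tf

def indrnn_step_input_ln(x_t, h_prev, W, u, b, ln, act=tf.nn.tanh):
    # IndRNN step with layer norm applied to the input transform only:
    #   h_t = act(LN(W x_t) + u * h_{t-1} + b)
    return act(ln(tf.matmul(x_t, W)) + u * h_prev + b)
```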
@sharvil Wonderful.

WEIGHT INITIALIZATION: author repo …

ACTIVATION: …

LAYER NORMALIZATION: The benefits of LayerNorm or BatchNorm, especially implemented recurrently for RNNs, are basically universal (some interesting reading here), and will be even more pronounced for long-sequence tasks with typical vanishing gradients. For IndRNNs, it remains important to normalize both the input-to-hidden and hidden-to-hidden transforms, and separately (see the sketch after this comment); the authors of recurrent batch normalization go so far as to normalize gates separately, with sound arguments. The idea is that information flow dynamics are unique to each transform, and in IndRNN's … Btw, Recurrent BN may prove superior to Recurrent LN, though it's harder to implement - but that's for another time.

DROPOUTS: Should behave the same as TF/Keras's SimpleRNN. Though …
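To illustrate the "both transforms, separately" point, a sketch under the same assumptions as above (placeholder names, not Haste's API), with one LayerNormalization per transform:

```python
import tensorflow as tf

def indrnn_step_dual_ln(x_t, h_prev, W, u, b, ln_x, ln_h, act=tf.nn.relu):
    # IndRNN step with separate layer norms on the input and recurrent
    # transforms: h_t = act(LN_x(W x_t) + LN_h(u * h_{t-1}) + b)
    return act(ln_x(tf.matmul(x_t, W)) + ln_h(u * h_prev) + b)
```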
This implementation doesn't support Zoneout regularization or Layer Normalization. The weight initialization scheme is primitive, and the RNN layer doesn't support activation functions other than `tanh`. Still, this change provides a good starting point and the required scaffolding for a more sophisticated implementation. Issue: #7
Thanks for the detailed writeup. I'm following your recommendations for the most part, but I don't think dropout is a good addition. Specifically, recurrent dropout is known to be A Bad Idea™ for RNNs, as you pointed out, and dropout between layers can be implemented trivially by the caller – it doesn't need to be implemented by the Haste layer.
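For example, a sketch of caller-side dropout between stacked layers (tf.keras's SimpleRNN standing in here for whatever recurrent layer is actually used):

```python
import tensorflow as tf

# Dropout applied between stacked recurrent layers by the caller,
# rather than inside the recurrent layer itself.
inputs = tf.keras.Input(shape=(None, 64))
x = tf.keras.layers.SimpleRNN(128, return_sequences=True)(inputs)
x = tf.keras.layers.Dropout(0.25)(x)  # ordinary dropout on inter-layer activations
outputs = tf.keras.layers.SimpleRNN(128)(x)
model = tf.keras.Model(inputs, outputs)
```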
@sharvil I see your commit history - good progress so far, though I hope LayerNorm support is planned, as it can make IndRNNs vastly more powerful for very long sequences. Regarding recurrent dropout, I disagree - it can be a very effective regularizer if used properly, though I wouldn't follow TensorFlow's implementation for it - I'll dedicate an Issue to this sometime.
Yes, Layer Norm support is planned, though no specific ETA yet.
This change consists of two optimizations: 1. Perform all pointwise operations in a single CUDA kernel (even across time steps). 2. Reduce memory accesses in the backward pass by accumulating sums in registers and writing back to global memory at the end of the kernel. The result is a ~20% speed improvement for a single training iteration that consists of a forward and backward pass. Issue: #7
This patch adds layer normalization to the input activations but not the recurrent activations. In a future patch, we'll support normalizing recurrent activations as well. Issue: #7
I have tested IndRNN - do I understand correctly that it is not production-ready yet? Training fails because the gradient of `indrnn_cell_36/recurrent_scale:0` has the wrong shape: grad shape is (512, 512) (but all rows except the first are zeros).
@amurashov thanks for the report. Looks like a bug in the tensor shape in my TensorFlow integration, which I've now fixed. IndRNN is at a functional stage; LayerNormIndRNN is usable but only performs layer norm on the inputs, not on the recurrent connection.
Yes, appears fixed. Thanks!
@sharvil Thank you for your awesome work. This is pure gold!! |
@bratao, |
I'd like to suggest support for IndRNNs; in my experiments on EEG seizure classification w/ very long sequences, they've dominated LSTMs & GRUs. While already much faster, IndRNNs would further benefit from a CuDNN-like speedup in large stacks, and from Layer Normalization for working w/ 1000+ timesteps.
Minimal tf.keras code below; default weight initialization should be handled differently - can clarify post-approval.

IndRNN Cell
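A minimal sketch of such a cell in tf.keras (initializers here are placeholders; as noted above, default weight initialization should be handled differently):

```python
import tensorflow as tf

class IndRNNCell(tf.keras.layers.Layer):
    # IndRNN update: h_t = act(W x_t + u * h_{t-1} + b), where u is an
    # element-wise (per-unit) recurrent weight vector rather than a matrix.
    def __init__(self, units, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = units
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel", shape=(input_shape[-1], self.units),
            initializer="glorot_uniform")  # placeholder initializer
        self.recurrent_kernel = self.add_weight(
            name="recurrent_kernel", shape=(self.units,),
            initializer="random_uniform")  # placeholder; see note on init above
        self.bias = self.add_weight(
            name="bias", shape=(self.units,), initializer="zeros")

    def call(self, inputs, states):
        h_prev = states[0]
        h = self.activation(
            tf.matmul(inputs, self.kernel)
            + self.recurrent_kernel * h_prev
            + self.bias)
        return h, [h]

# Usage: wrap the cell in the generic RNN layer.
# layer = tf.keras.layers.RNN(IndRNNCell(128), return_sequences=True)
```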