It just explodes!!! #23

Open
VCasecnikovs opened this issue Jul 23, 2020 · 3 comments

@VCasecnikovs

Hello, I've been testing weight standardization (WS) on my dataset and my network. I have read about the std numerical-stability issue, but even after using
std = (torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1) + 1e-5), I found that the activations still explode. When I use a basic Conv2d block this problem does not occur. So my question is: is there a way to fix this?
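
For reference, this is roughly the weight-standardized conv I am using (the class name WSConv2d is just for illustration; the std computation is the one quoted above):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class WSConv2d(nn.Conv2d):
        """Conv2d whose weights are standardized per output channel before the convolution."""
        def forward(self, x):
            weight = self.weight
            # per-output-channel mean over (in_channels, kH, kW)
            mean = weight.mean(dim=(1, 2, 3), keepdim=True)
            # per-output-channel std with the two epsilon terms mentioned above
            std = (torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12)
                   .view(-1, 1, 1, 1) + 1e-5)
            weight = (weight - mean) / std
            return F.conv2d(x, weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)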

@MohitLamba94

Adding an affine transformation right after standardization, as in https://github.com/open-mmlab/mmcv/blob/d5cbf7eed1269095bfba1a07913efbbc99d2d10b/mmcv/cnn/bricks/conv_ws.py#L54, might help. I have not tried it myself, but it was originally used by Google Research to avoid NaNs and might work for you as well. Let me know if it helps.
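
Roughly what that looks like, as a sketch of the idea rather than the exact mmcv implementation (parameter names and initialization there may differ):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ConvAWS2d(nn.Conv2d):
        """Weight standardization followed by a learnable per-channel affine (gamma, beta)."""
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # learnable scale and shift applied to the standardized weights
            self.weight_gamma = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
            self.weight_beta = nn.Parameter(torch.zeros(self.out_channels, 1, 1, 1))

        def forward(self, x):
            weight = self.weight
            mean = weight.mean(dim=(1, 2, 3), keepdim=True)
            std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
            weight = (weight - mean) / std
            # affine transformation right after standardization
            weight = self.weight_gamma * weight + self.weight_beta
            return F.conv2d(x, weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)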

@csvance

csvance commented Jan 15, 2024

It explodes even in the forward pass because activation values tend to be much larger with weight standardization: the weights are normalized to std = 1 instead of something like gain / sqrt(fan).

Try the following forward pass, changing the gain if you use an activation function other than ReLU:

    def forward(self, x):
        # standardize each output channel's weights to mean 0, std 1
        weight = F.batch_norm(
            self.weight.reshape(1, self.out_channels, -1), None, None,
            training=True, momentum=0., eps=self.eps).reshape_as(self.weight)

        # rescale to the He normal std: gain / sqrt(fan_out), gain chosen for ReLU
        gain = nn.init.calculate_gain('relu')
        fan = nn.init._calculate_correct_fan(self.weight, 'fan_out')
        std = gain / fan ** 0.5
        weight = std * weight

        x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x

This way the weights are standardized similarly to He normal initialization, so you still get the benefits of weight standardization without the activation scales exploding.
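
In case it helps, here is the same forward wrapped into a self-contained module (the class name ScaledWSConv2d and the eps default are mine, just for illustration):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ScaledWSConv2d(nn.Conv2d):
        """Weight standardization rescaled to the He normal (fan_out) std."""
        def __init__(self, *args, eps=1e-5, **kwargs):
            super().__init__(*args, **kwargs)
            self.eps = eps

        def forward(self, x):
            # standardize each output channel's weights to mean 0, std 1
            weight = F.batch_norm(
                self.weight.reshape(1, self.out_channels, -1), None, None,
                training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
            # rescale so the effective std matches He initialization for ReLU
            gain = nn.init.calculate_gain('relu')
            fan = nn.init._calculate_correct_fan(self.weight, 'fan_out')
            weight = (gain / fan ** 0.5) * weight
            return F.conv2d(x, weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)

    # quick sanity check
    conv = ScaledWSConv2d(16, 32, kernel_size=3, padding=1)
    out = conv(torch.randn(4, 16, 32, 32))   # -> torch.Size([4, 32, 32, 32])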

@csvance

csvance commented Jan 15, 2024

Just to summarize my overall thoughts on using this method:

  1. People are used to training nets with He initialization, but with weight standardization the activation scales are considerably higher than with He normal initialization. This can lead to NaN values during the forward pass, especially when using mixed precision. In practice it can be solved by scaling the standardized weights according to a popular initialization scheme such as He normal.
  2. A common practice when initializing a ResNet is to set the weights of the last layer in each block to zero to avoid exploding gradients at initialization; this is mentioned in the original ResNet paper. With weight standardization this is not really practical: even if you zero the weights, after the first step they will simply point in the direction of the first gradient. Effectively every block is fully ungated at initialization, so to speak, which causes exploding gradients and training instability, even if convergence appears faster at first. In my experience aggressive gradient clipping is needed to keep the weight distributions from distorting, especially with an adaptive optimizer like Adam (see the sketch after this list).
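
For the clipping in point 2, the call goes between backward and the optimizer step; a minimal sketch (the toy model, optimizer, and max_norm value are placeholders to tune, not recommendations):

    import torch
    import torch.nn as nn

    # toy model and data just to show where the clipping call sits
    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                          nn.Flatten(), nn.Linear(8 * 32 * 32, 10))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    x, target = torch.randn(4, 3, 32, 32), torch.randint(0, 10, (4,))

    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(x), target)
    loss.backward()
    # clip the global gradient norm before the optimizer step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()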
