
Network's performance decreases after adopting WS #20

Open
NoOneUST opened this issue Mar 5, 2020 · 5 comments

Comments

@NoOneUST

NoOneUST commented Mar 5, 2020

I am training an instance segmentation network. Before adopting WS, I achieved 35.66 mAP with Conv+GN; after adopting WS, I only achieve 35.27. Is there something wrong with my code? My code for converting the original network to WS is below. Note that my original network contains a ResNet101-FPN backbone with deformable convs, as well as depthwise-separable convs and the linear bottlenecks introduced in MobileNet-V2.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv2d(nn.Conv2d):
    # Conv2d with Weight Standardization (WS).

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                     padding, dilation, groups, bias)

    def forward(self, x):
        weight = self.weight
        # Center each output filter: subtract its mean over (in_channels, kH, kW).
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
                                  keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        # Scale each output filter to unit std (eps keeps the division finite).
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


# Normalization layers that must follow a conv for it to be converted to WS.
NORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
              nn.GroupNorm, nn.LayerNorm)


def convertConv2WeightStand(module, nextChild=None):
    mod = module
    # The WS class above is 2D-only, so only plain nn.Conv2d layers are converted;
    # Conv1d/Conv3d/ConvTranspose* layers are left unchanged.
    # nextChild is the next sibling in registration order, which is not
    # necessarily the order in which forward() calls the layers.
    if type(module) is nn.Conv2d and isinstance(nextChild, NORM_TYPES):
        mod = Conv2d(module.in_channels, module.out_channels, module.kernel_size,
                     module.stride, module.padding, module.dilation,
                     module.groups, module.bias is not None)
        # Copy the existing (e.g. pre-trained) weights instead of re-initializing.
        mod.load_state_dict(module.state_dict())

    moduleChildList = list(module.named_children())
    for index, (name, child) in enumerate(moduleChildList):
        nextChild = None
        if index < len(moduleChildList) - 1:
            nextChild = moduleChildList[index + 1][1]
        mod.add_module(name, convertConv2WeightStand(child, nextChild))

    return mod


if cfg.useWeightStandardization:
    net = convertConv2WeightStand(net)
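To see what the WS forward pass does to a weight tensor in isolation, the sketch below repeats the standardization from Conv2d.forward in a standalone helper (standardize_weight is a hypothetical name, not part of this repository) and checks two properties: every output filter ends up with mean close to 0 and std close to 1, and rescaling or shifting a filter leaves the standardized weight essentially unchanged.

```python
import torch


def standardize_weight(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Same math as Conv2d.forward above: per-output-filter centering, then
    # division by the (unbiased) std over (in_channels, kH, kW) plus eps.
    flat = weight.view(weight.size(0), -1)
    centered = flat - flat.mean(dim=1, keepdim=True)
    std = centered.std(dim=1, keepdim=True) + eps
    return (centered / std).view_as(weight)


torch.manual_seed(0)
w = torch.randn(8, 4, 3, 3)
w_hat = standardize_weight(w)
flat = w_hat.view(8, -1)
print(flat.mean(dim=1).abs().max())  # close to 0
print(flat.std(dim=1))               # close to 1 for every filter

# A per-filter positive scale and shift of the raw weight is removed by WS.
scale = torch.rand(8, 1, 1, 1) + 0.5
shift = torch.randn(8, 1, 1, 1)
print(torch.allclose(standardize_weight(w * scale + shift), w_hat, atol=1e-3))
```

The second property means a conv trained without WS computes something different after conversion: the per-filter mean and scale it learned are discarded, unless the weights were already trained under WS.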
@joe-siyuan-qiao
Owner

Thanks for the question. Did you also use the backbones pre-trained with WS? Also, make sure every WS-Conv2d is followed by an activation normalization layer; otherwise, use a regular Conv2d.
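A rough way to audit a converted model for WS convs that are not followed by a normalization layer is sketched below. It is a minimal sketch under stated assumptions: WSConv2d and ws_convs_without_norm are hypothetical names, and "followed by" means the next sibling in registration order, which only matches execution order for nn.Sequential-style containers, so modules with a custom forward() still need a manual look.

```python
import torch.nn as nn
import torch.nn.functional as F


class WSConv2d(nn.Conv2d):
    # Weight-standardized conv (same math as the Conv2d class in the question).
    def forward(self, x):
        w = self.weight
        w = w - w.mean(dim=(1, 2, 3), keepdim=True)
        w = w / (w.view(w.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


NORM_TYPES = (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)


def ws_convs_without_norm(model: nn.Module):
    """Return names of WSConv2d layers whose next sibling is not a norm layer."""
    missing = []
    for parent_name, parent in model.named_modules():
        children = list(parent.named_children())
        for i, (name, child) in enumerate(children):
            if isinstance(child, WSConv2d):
                nxt = children[i + 1][1] if i + 1 < len(children) else None
                if not isinstance(nxt, NORM_TYPES):
                    missing.append(f"{parent_name}.{name}" if parent_name else name)
    return missing


model = nn.Sequential(
    WSConv2d(3, 16, 3, padding=1), nn.GroupNorm(4, 16), nn.ReLU(),
    WSConv2d(16, 16, 3, padding=1), nn.ReLU(),
)
print(ws_convs_without_norm(model))  # ['3']
```

Any name it reports is a candidate for switching back to a regular Conv2d.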

@NoOneUST
Author

NoOneUST commented Mar 7, 2020


Thanks for your reply. I tried both:

  1. replacing conv+BN in the backbone and FPN with WS+BN, then fine-tuning;
  2. keeping conv+BN in the backbone and FPN unchanged.

All other network components use WS. In both cases, performance decreased. I have verified the network architecture: only convs directly followed by BN are replaced with WS. One question: for combined convs such as a LinearBottleNeck followed by BN, i.e. 3x3+1x1+3x3+BN, should WS replace only the last 3x3 conv or all three convs? In my code, I chose the former.
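The two options in this question can be written down as a small module. This is a hypothetical sketch (CombinedConvBN, ws_mode, and WSConv2d are illustrative names), and the three convs are chained directly, with no activations, only to keep it short. It makes the choice explicit; it does not settle which option works better.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WSConv2d(nn.Conv2d):
    # Weight Standardization, as in the Conv2d class from the question.
    def forward(self, x):
        w = self.weight - self.weight.mean(dim=(1, 2, 3), keepdim=True)
        w = w / (w.flatten(1).std(dim=1).view(-1, 1, 1, 1) + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


class CombinedConvBN(nn.Module):
    """Hypothetical 3x3 -> 1x1 -> 3x3 stack followed by a single BN.

    ws_mode="last": only the conv that directly feeds BN uses WS.
    ws_mode="all":  all three convs use WS, so the first two are not
                    directly followed by a normalization layer.
    """

    def __init__(self, channels: int, ws_mode: str = "last"):
        super().__init__()
        if ws_mode not in ("last", "all"):
            raise ValueError(f"unknown ws_mode: {ws_mode!r}")
        early = WSConv2d if ws_mode == "all" else nn.Conv2d
        self.conv1 = early(channels, channels, 3, padding=1, bias=False)
        self.conv2 = early(channels, channels, 1, bias=False)
        self.conv3 = WSConv2d(channels, channels, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        return self.bn(self.conv3(self.conv2(self.conv1(x))))


x = torch.randn(2, 8, 16, 16)
for mode in ("last", "all"):
    block = CombinedConvBN(8, ws_mode=mode)
    n_ws = sum(isinstance(m, WSConv2d) for m in block.modules())
    print(mode, tuple(block(x).shape), n_ws)
```

With ws_mode="last" one conv uses WS; with ws_mode="all" three do.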

@joe-siyuan-qiao
Owner

Sorry, it's hard for me to tell where the problem might be from the details you provided. However, one thing I would recommend trying is removing weight /= std, i.e., only centering the weights. This gives up the benefits of dividing by the std but is more tolerant of different architecture designs. This strategy might also apply to the combined convolutions.
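The centering-only variant can be sketched as the question's Conv2d.forward with the division by std removed. WCConv2d is a hypothetical name; this is a minimal sketch, not code from this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WCConv2d(nn.Conv2d):
    """Hypothetical centering-only variant of the WS Conv2d from the question."""

    def forward(self, x):
        # Subtract each output filter's mean over (in_channels, kH, kW);
        # unlike WS, the weights are NOT divided by their std afterwards.
        w = self.weight - self.weight.mean(dim=(1, 2, 3), keepdim=True)
        return F.conv2d(x, w, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


torch.manual_seed(0)
conv = WCConv2d(4, 8, kernel_size=3)  # no padding
x = torch.full((1, 4, 10, 10), 3.0)   # constant input shared by all channels
y = conv(x)
print(y.shape)  # torch.Size([1, 8, 8, 8])
# Each centered filter sums to zero, so only the bias survives.
print(torch.allclose(y, conv.bias.view(1, -1, 1, 1).expand_as(y), atol=1e-5))
```

Because each centered filter sums to zero, the layer ignores a constant offset shared by all input channels, while the filter's scale is left as trained.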

@gautamsreekumar

@joe-siyuan-qiao Why must WS be followed by a normalization layer? As I understand it, WS aims to preserve the statistics of the tensors, so shouldn't it be useful even for layers without normalization?

In other words, WS can pass the statistical similarities from the input channels to the output channels, all the way from the image space where RGB channels are properly normalized.
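The mean and variance part of this argument can be worked out directly. The following is a sketch under simplifying assumptions that are not stated in the thread: each output is treated as a linear map $y_i = \sum_j \hat W_{i,j} x_j$ over the $I = C_\text{in} k_h k_w$ inputs of one window (bias and $\epsilon$ ignored), and the input entries are independent with a common mean $\mu_x$ and variance $\sigma_x^2$.

```math
\begin{aligned}
\hat W_{i,j} &= \frac{W_{i,j} - \mu_i}{\sigma_i}, \quad
\mu_i = \frac{1}{I}\sum_{j=1}^{I} W_{i,j}, \quad
\sigma_i^2 = \frac{1}{I}\sum_{j=1}^{I} \left(W_{i,j} - \mu_i\right)^2 \\
&\Rightarrow\ \sum_{j} \hat W_{i,j} = 0, \qquad \sum_{j} \hat W_{i,j}^2 = I \\
\mathbb{E}[y_i] &= \sum_j \hat W_{i,j}\,\mathbb{E}[x_j] = \mu_x \sum_j \hat W_{i,j} = 0 \\
\operatorname{Var}[y_i] &= \sum_j \hat W_{i,j}^2 \operatorname{Var}[x_j] = \sigma_x^2 \sum_j \hat W_{i,j}^2 = I\,\sigma_x^2
\end{aligned}
```

Under these assumptions every output channel gets the same mean and variance. The code in this issue uses the unbiased std, which replaces $I$ by $I - 1$ in the last line (ignoring $\epsilon$) but does not change the conclusion.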

@hiyyg

hiyyg commented Nov 7, 2021

Just for your reference, on my task, GN > GN+WC (weight centralization) >> GN+WS.
