Questions about VNMaxPool #11

Open · guochengqian opened this issue Feb 19, 2022 · 4 comments
@guochengqian commented Feb 19, 2022

Dear @FlyingGiraffe,

Thanks for your impressive work.

I have a question about the VNMaxPool layer. You use argmax to select the idx of the pooled features. However, the argmax is not differentiable, so the self.map_to_dir layer will never be updated.

d = self.map_to_dir(x.transpose(1,-1)).transpose(1,-1)

Could you explain this a little bit? Maybe I am wrong.
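
For context, the forward pass looks roughly like this (my paraphrase, not the verbatim repo code; shapes assumed to be [B, N_feat, 3, N_samples], and torch.gather stands in for the repo's index-based selection):

# Rough sketch of the argmax pooling path (paraphrased, not verbatim).
d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)  # learned directions
dotprod = (x * d).sum(2, keepdim=True)                    # [B, N_feat, 1, N_samples]
idx = dotprod.max(dim=-1, keepdim=True)[1]                # integer argmax: no gradient
x_max = torch.gather(x, -1, idx.expand(-1, -1, x.size(2), -1)).squeeze(-1)
# Gradients reach x through the gathered columns, but the path through
# dotprod -> d -> self.map_to_dir ends at the integer idx.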

Thank you very much!

@Julien-Gustin

Hi @guochengqian,
I have the same problem. Have you found a way to solve it?

Thank you. (cc @FlyingGiraffe)

@guochengqian (Author)

@Julien-Gustin Hi, thanks for the ping. No, I gave up on this problem.

@chrockey

Hi @FlyingGiraffe, thanks for open-sourcing your code.
Could you please take a look at this issue?

@udaykamal20

I tried using a soft-argmax but still ended up with the same error (in DDP mode, which errors out when a parameter receives no gradient), so I finally went with softmax pooling instead of hard max pooling. This keeps the operation differentiable, so self.map_to_dir gets updated.

import torch
import torch.nn as nn
import torch.nn.functional as F


class VNMaxPool(nn.Module):

    def __init__(self, in_channels):
        super(VNMaxPool, self).__init__()
        # Learns a direction per channel; pooling keeps the sample whose
        # feature vector has the largest dot product with its direction.
        self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False)

    def forward(self, x):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]
        '''
        d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)
        dotprod = (x * d).sum(2, keepdim=True)  # [B, N_feat, 1, N_samples, ...]
        # Temperature of 100 (or any large value) sharpens the softmax so the
        # max sample dominates, differentiably approximating the hard argmax.
        sm = F.softmax(dotprod * 100., dim=-1)
        x_max = (x * sm).sum(-1)
        return x_max
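
A quick sanity check (made-up shapes, just for illustration) confirms that gradients now reach map_to_dir:

pool = VNMaxPool(in_channels=64)
x = torch.randn(2, 64, 3, 128, requires_grad=True)  # [B, N_feat, 3, N_samples]
pool(x).sum().backward()
print(pool.map_to_dir.weight.grad.abs().sum() > 0)  # tensor(True)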
