
AssertionError: input should be in float32 type, got torch.float16 #9

Open · lawrence-ff opened this issue Mar 5, 2024 · 5 comments
@lawrence-ff

(attached screenshot: 微信图片_20240202181417)
Hi, this problem is caused by the assertions in the norm.py file: some parts require the input data type to be torch.float32, but the actual input data type is torch.float16. What is causing this problem? Is the interpreter not working?
Looking forward to your reply, thanks!
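
For reference, the assertion that fires looks roughly like the sketch below (illustrative only; the exact class and message in SST's norm.py may differ):

    import torch
    import torch.nn as nn

    class NaiveSyncBNSketch(nn.BatchNorm1d):
        """Sketch of the dtype guard that raises the error (names are illustrative)."""

        def forward(self, input):
            # The sync-BN statistics are reduced in fp32, so the layer rejects
            # half-precision input outright instead of silently casting it.
            assert input.dtype == torch.float32, \
                f'input should be in float32 type, got {input.dtype}'
            return super().forward(input)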

@georghess
Collaborator

Hi,

I think there is a bug in the mixed-precision training. A workaround is to change the forward method in the VFELayer from

    @auto_fp16(apply_to=('inputs'), out_fp32=True)
    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (torch.Tensor): Voxels features of shape (M, C).
                M is the number of points, C is the number of channels of point features.

        Returns:
            torch.Tensor: point features in shape (M, C).
        """
        # [K, T, 7] tensordot [7, units] = [K, T, units]
        x = self.linear(inputs)
        x = self.norm(x)
        pointwise = F.relu(x)
        return pointwise

to

    @auto_fp16(apply_to=('inputs'), out_fp32=False)
    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (torch.Tensor): Voxels features of shape (M, C).
                M is the number of points, C is the number of channels of point features.

        Returns:
            torch.Tensor: point features in shape (M, C).
        """
        # [K, T, 7] tensordot [7, units] = [K, T, units]
        x = self.linear(inputs)
        x = self.norm(x.float())
        pointwise = F.relu(x)
        return pointwise
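
The key change is the explicit x.float() cast, which guarantees the norm layer receives fp32 input (which is what the assertion in norm.py checks) regardless of how auto_fp16 casts the voxel features.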

@gorkemguzeler

Hi @georghess, I reproduced the same issue, and the proposed solution did not resolve the problem. I would appreciate any other ideas.

@georghess
Collaborator

Hi @gorkemguzeler, I'm pretty sure the solution above should work. That's what we've done on our dev branch at least, where we've switched to our own fork of SST that includes the above changes.

To help you beyond that, I'd need some more info. Could you send the entire error trace?

@gorkemguzeler

gorkemguzeler commented Sep 10, 2024

Hi @georghess, thanks a lot for your quick reply and information!

I switched my branch to dev and used your fork of SST. I did not run into the above issue this time.

One thing I noticed: when I started training on the main branch (after some updates in the norm.py and scatter_points.py files to avoid the issue above), training took around 1.5 hours per epoch. Additionally, I was getting the following warnings:

No voxel belongs to drop_level3 in shift 0
No voxel belongs to drop_level4 in shift 0

I just tried the dev branch with the forked SST; there are no such warnings during training, but it takes 11 hours per epoch.

  • Are the warnings above expected during training on the main branch? Should I resolve them for stable training?
  • What could be the reason for the huge difference in training time between these runs (e.g., a default hyperparameter, a dataset-specific method), given that I use the same GPU and dataset (nuScenes trainval split)? I also checked the lidar encoders, and both branches use SST.
  • Which branch would you recommend, given that I will implement some downstream tasks on top of your repository?

@daiduck

daiduck commented Nov 14, 2024

I also hit this error when using mmdet3d. In my case, the decorator @auto_fp16 controls precision conversion based on the fp16_enabled attribute; in some situations this attribute is set to False, so the output is not correctly converted to fp32. That's what happened to me.
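
For context, mmcv's auto_fp16 wrapper is a no-op unless the module's fp16_enabled attribute is True, so out_fp32=True has no effect on modules where the flag was never set. Below is a minimal sketch of that gating logic and of how the flag is normally switched on (assuming mmcv 1.x-style fp16 utilities such as wrap_fp16_model):

    import functools
    import torch

    def auto_fp16_sketch(apply_to=None, out_fp32=False):
        """Simplified stand-in for mmcv's auto_fp16, showing the fp16_enabled gate."""
        def wrapper(old_func):
            @functools.wraps(old_func)
            def new_func(self, *args, **kwargs):
                # If fp16_enabled was never set to True on this module, the decorator
                # does nothing: no inputs are cast to fp16 and no outputs back to fp32.
                if not getattr(self, 'fp16_enabled', False):
                    return old_func(self, *args, **kwargs)
                # ... otherwise cast the arguments named in `apply_to` to half precision,
                # call old_func, and cast the result to fp32 when out_fp32=True ...
                return old_func(self, *args, **kwargs)
            return new_func
        return wrapper

    def enable_fp16(model: torch.nn.Module) -> None:
        # Mirrors what mmcv's wrap_fp16_model does: flip the flag on every module
        # that declares it, so the auto_fp16 decorators actually take effect.
        for m in model.modules():
            if hasattr(m, 'fp16_enabled'):
                m.fp16_enabled = True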
