[BUG]: Got nan during backward with zero2 #6091

Open · flymin opened this issue Oct 16, 2024 · 9 comments
Labels: bug (Something isn't working)

flymin (Contributor) commented Oct 16, 2024

Is there an existing issue for this bug?

  • I have searched the existing issues

🐛 Describe the bug

My code is based on Open-Sora and runs without any issue on 32 GPUs using ZeRO-2.

However, when using 64 GPUs, NaN appears in the tensor gradients after the second backward step.

I have worked around this by patching colossalai/zero/low_level/low_level_optim.py with:

# line 313, in _run_reduction
flat_grads = bucket_store.get_flatten_grad()
flat_grads /= bucket_store.world_size
if torch.isnan(flat_grads).any():  # added: fail fast if the flattened grads contain NaN
    raise RuntimeError(f"rank {dist.get_rank()} got nan on flat_grads")  # added
...
        if received_grad.dtype != grad_dtype:
            received_grad = received_grad.to(grad_dtype)
        if torch.isnan(received_grad).any():  # added: fail fast if the reduce-scattered grads contain NaN
            raise RuntimeError(f"rank {dist.get_rank()} got nan on received_grad")  # added
...

With the patch above, my code runs normally and the loss looks fine.

I think it may be related to unsynchronized state between CUDA streams. I do not know the exact reason, and I do not think my workaround really solves the issue.

Any ideas from the team?

Environment

NVIDIA H20
ColossalAI version: 0.4.3
CUDA 12.4
PyTorch 2.4

flymin added the bug (Something isn't working) label on Oct 16, 2024
Edenzzzz (Contributor) commented Oct 26, 2024

FP16 ZeRO should auto-check for overflow and skip that step, though this seems unimplemented for bf16. @botbw would you be able to take a look? I haven't been maintaining this part.
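For illustration, a minimal sketch of such a guard that would work for any precision (this is not ColossalAI's actual implementation; the helper name and the skip logic are hypothetical):

import torch
import torch.distributed as dist

def grads_are_finite(flat_grads: torch.Tensor) -> bool:
    # Detect inf/NaN in the reduced gradients on this rank.
    found_bad = (~torch.isfinite(flat_grads)).any().float()
    # Agree across ranks so every rank skips (or steps) together.
    dist.all_reduce(found_bad, op=dist.ReduceOp.MAX)
    return found_bad.item() == 0.0

# Hypothetical usage inside the optimizer step:
# if grads_are_finite(flat_grads):
#     optimizer.step()
# else:
#     optimizer.zero_grad()  # skip this step instead of applying NaN updates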

Edenzzzz assigned botbw, wangbluo, and ver217 and unassigned wangbluo on Oct 26, 2024
Edenzzzz (Contributor) commented

Since you're using Open-Sora, feel free to open this kind of issue in their repo too.

Edenzzzz (Contributor) commented Oct 26, 2024

Indeed, bf16 has the same range as fp32, but in my opinion this check could be enforced for all precisions?

flymin (Contributor, Author) commented Oct 26, 2024

Hi @Edenzzzz, thank you for getting involved.

Please note that with the added lines, NaN no longer occurs and the RuntimeError is never actually raised by them. Therefore, I don't think skipping a specific iteration would help. I suspect the bug is in communication, which is also why I opened the issue in this repo.

Edenzzzz (Contributor) commented Oct 26, 2024

Hmm, then I think isnan triggers a synchronization? You should check whether the NaN is in received_grad or in flat_grads. You can also try removing those lines and putting a torch.cuda.synchronize() there instead.
dist.reduce_scatter is synchronous here and received_grad is initialized to zeros, not NaN, so communication might not be the issue.
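For example, to separate the effect of the implicit synchronization from the value check itself, the isnan guards could be replaced by an explicit sync at the same spot (a sketch reusing the names from the snippet above):

import torch

# In _run_reduction, instead of the isnan checks:
flat_grads = bucket_store.get_flatten_grad()
flat_grads /= bucket_store.world_size
torch.cuda.synchronize()  # wait for all pending kernels before the reduce-scatter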

botbw (Contributor) commented Oct 27, 2024

@flymin Thanks for reporting this! Would it be possible to share the config/code snippet you are using? If not, could you try setting overlap_communication=False in LowLevelZeroPlugin and check whether the problem still exists? From that we can conclude whether there might be a bug in communication stream synchronization.
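For reference, the flag is passed when constructing the plugin; a sketch, with the other arguments as placeholders for whatever the actual training script uses:

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Run reduce-scatter on the compute stream instead of a separate comm stream.
plugin = LowLevelZeroPlugin(
    stage=2,                      # ZeRO-2
    precision="bf16",
    overlap_communication=False,  # the setting suggested above
)
booster = Booster(plugin=plugin)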

flymin (Contributor, Author) commented Oct 27, 2024

Sorry, I cannot provide my current code. I may have some time in late November to work on this issue again.

I have tried adding synchronization code in this function, but it does not help. I also tried overlap_communication=False; the issue still exists.

In my workaround, I have to add both if blocks; either one alone does not help.

Edenzzzz (Contributor) commented Oct 27, 2024

If you suspect communication issues, you can put a dist.barrier() there.
You should also try printing or inserting dist.breakpoint() to find the first variable that becomes NaN.
I've also used isnan for pipeline-parallel debugging, but I didn't see it doing any special synchronization... In my case a grad was not received somewhere and was then propagated across layers, creating NaN.
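For example, a small per-rank helper along these lines can locate the first NaN gradient (a sketch; the helper name is made up):

import torch
import torch.distributed as dist

def report_first_nan_grad(model: torch.nn.Module):
    # Walk the parameters and report the first gradient containing NaN on this rank.
    for name, p in model.named_parameters():
        if p.grad is not None and torch.isnan(p.grad).any():
            print(f"rank {dist.get_rank()}: first NaN grad in {name}", flush=True)
            # dist.breakpoint()  # optionally drop into pdb on the offending rank
            return name
    return None

# Call it right after loss.backward() on every rank.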

flymin (Contributor, Author) commented Oct 29, 2024

barrier does not help. I will try dist.breakpoint() later. Thank you for your advice.
