[BUG]: weird stuck while training #6095

Open · 1 task done
ericxsun opened this issue Oct 19, 2024 · 8 comments
Labels: bug (Something isn't working)

@ericxsun (Contributor)

Is there an existing issue for this bug?

  • I have searched the existing issues

🐛 Describe the bug

When training a language model with the GeminiPlugin, I encountered an issue where the process got stuck during the forward step. I was saving a checkpoint every 3000 steps, and when it got stuck, I had to kill the process and resume from the latest checkpoint.

The steps at which training got stuck:

| start step | stuck step | total steps in each run |
|-----------:|-----------:|------------------------:|
| 225000     | 271464     | 46464                   |
| 180000     | 226463     | 46463                   |
| 135000     | 181463     | 46463                   |
| 90000      | 136463     | 46463                   |
| 45000      | 91463      | 46463                   |
| 0          | 46465      | 46465                   |

Any idea how to find out why? Thanks a lot.

Environment

CUDA: 12.1
NCCL: 2.18
PyTorch: 2.1.2
Python: 3.8
ColossalAI: 0.4.2
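
For reference, a minimal sketch of the setup described above (Booster with GeminiPlugin, checkpointing every 3000 steps). This is illustrative only: the model/dataloader construction, paths, and loop structure are placeholders, not the actual training script.

```python
# Illustrative sketch only -- not the actual training script from this report.
# build_model_optimizer_dataloader() is a hypothetical helper standing in for
# the real model/data setup.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()

plugin = GeminiPlugin()
booster = Booster(plugin=plugin)

model, optimizer, dataloader = build_model_optimizer_dataloader()
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

for step, batch in enumerate(dataloader, start=1):
    loss = model(**batch).loss
    booster.backward(loss, optimizer)
    optimizer.step()
    optimizer.zero_grad()

    # Save a checkpoint every 3000 steps, as described in the report.
    if step % 3000 == 0:
        booster.save_model(model, f"ckpt/step_{step}/model", shard=True)
        booster.save_optimizer(optimizer, f"ckpt/step_{step}/optimizer", shard=True, size_per_shard=2048)
```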

ericxsun added the bug (Something isn't working) label on Oct 19, 2024
@Edenzzzz (Contributor) commented Oct 21, 2024

Can you share any relevant messages and the stack trace from when it gets stuck or exits?

@ericxsun (Contributor, Author)

> Can you share any relevant messages and the stack trace from when it gets stuck or exits?

I didn’t receive any useful information or logs. All nodes seem to be functioning correctly. The only option I have is to kill the training process and resume it.

With more logging added, I can see that the process gets stuck at the forward step.

@Edenzzzz (Contributor) commented Oct 22, 2024

Could you share the stack trace printed when you kill it with Ctrl-C, and a reproducible script?

@ericxsun (Contributor, Author) commented Nov 4, 2024

> Could you share the stack trace printed when you kill it with Ctrl-C, and a reproducible script?

Could it be caused by the weird behavior described in #6111?

@Edenzzzz (Contributor) commented Nov 4, 2024

You can probably test the behavior of all_gather_object and see if it spawns multiple processes.
What happens with booster.save_optimizer(optimizer, path_optimizer, shard=True, size_per_shard=2048) is that it calls into save_sharded_optimizer, which all_gathers the states. You can try removing some barriers along this call stack and ping other members with your findings (i.e. whether that fixes the hang).
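
A minimal way to test this in isolation (a sketch, assuming one process per GPU launched via torchrun; the tensor contents just mimic the compacted optimizer states):

```python
# Sketch: check whether all_gather_object with a CUDA tensor in the payload
# makes extra PIDs show up on each GPU in nvidia-smi.
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Stand-in for the compacted optimizer states gathered during save_sharded_optimizer.
    payload = torch.zeros(1024, device=torch.device(f"cuda:{torch.cuda.current_device()}"))

    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, [payload, rank], group=None)

    # Pause here and inspect `nvidia-smi`: if unpickling the gathered CUDA tensors
    # creates new CUDA contexts, each GPU will list the PIDs of the other ranks.
    dist.barrier()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()  # e.g. torchrun --nproc_per_node=<num_gpus> test_all_gather_object.py
```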

@ericxsun (Contributor, Author) commented Nov 4, 2024

I observed that, after this line:

compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None

the PIDs of the other ranks start appearing on rank 0.

Furthermore, after reaching this line:

compacted_states = torch.zeros(compacted_size, dtype=dtype, device=device, requires_grad=False)

if device is replaced with torch.device(f"cuda:{torch.cuda.current_device()}"), each rank retains only one PID, just as at the start:

compacted_states = torch.zeros(
    compacted_size,
    dtype=dtype,
    device=torch.device(f"cuda:{torch.cuda.current_device()}"),
    requires_grad=False
) 

And after reaching this line:

dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size], group=zero_group)

the PIDs of the other ranks still start appearing on each rank.
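
To isolate the device part of this observation, here is a small sketch (assuming one process per GPU; not the ColossalAI code itself) contrasting allocation on a fixed device index with allocation on the current device. It shows one plausible mechanism consistent with the observation above, not a confirmed diagnosis:

```python
# Sketch: extra PIDs from allocating on another rank's GPU vs. the current one.
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

# Variant A (problematic): every rank allocates on cuda:0, so cuda:0 ends up
# with a CUDA context -- and a PID in nvidia-smi -- from every rank.
# t = torch.zeros(1024, device="cuda:0")

# Variant B (what the replacement above does): each rank stays on its own device,
# so each GPU keeps a single PID.
t = torch.zeros(1024, device=torch.device(f"cuda:{torch.cuda.current_device()}"), requires_grad=False)

dist.barrier()
dist.destroy_process_group()
```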

@ericxsun (Contributor, Author) commented Nov 4, 2024

Hi @ver217, could you take a look? Thanks very much.

@Edenzzzz (Contributor) commented Nov 6, 2024

> And after reaching this line:
>
> dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size], group=zero_group)
>
> the PIDs of the other ranks still start appearing on each rank.

This might just be the default behavior. All-gather by definition collects tensor-based objects from other ranks:
https://discuss.pytorch.org/t/distributed-all-gather-object-produces-multiple-additional-processes/164991
For the hang, please try removing the dist.barrier call.
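
If the extra CUDA contexts themselves are a concern, one common workaround (an assumption, not something verified in this thread) is to gather CPU copies so that unpickling on the receiving ranks does not touch their GPUs. A sketch against the names used above:

```python
# Hypothetical workaround sketch: gather CPU copies of the compacted states so
# that deserialization on other ranks does not create new CUDA contexts.
import torch
import torch.distributed as dist

def gather_state_shards(compacted_states, shard_offset, shard_size, zero_group):
    """Gather (states, offset, size) from every rank in `zero_group`."""
    world_size = dist.get_world_size(group=zero_group)
    gathered_state_shards = [None] * world_size
    payload = [
        compacted_states.cpu() if compacted_states is not None else None,
        shard_offset,
        shard_size,
    ]
    dist.all_gather_object(gathered_state_shards, payload, group=zero_group)
    return gathered_state_shards
```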
