
[BUG] 0.9.0 release version got param_gather_handle error with 3d parallel #1292

Open
SeunghyunSEO opened this issue Nov 19, 2024 · 6 comments


SeunghyunSEO commented Nov 19, 2024

 5: [rank5]:   File "/workspace/megatron/core/transformer/transformer_block.py", line 493, in forward
 5: [rank5]:     hidden_states, context = layer(
 5: [rank5]:   File "/workspace/megatron/core/transformer/transformer_layer.py", line 426, in __call__
 5: [rank5]:     return super(MegatronModule, self).__call__(*args, **kwargs)
 5: [rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
 5: [rank5]:     return self._call_impl(*args, **kwargs)
 5: [rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
 5: [rank5]:     result = forward_call(*args, **kwargs)
 5: [rank5]:   File "/workspace/megatron/core/transformer/transformer_layer.py", line 369, in forward
 5: [rank5]:     mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
 5: [rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
 5: [rank5]:     return self._call_impl(*args, **kwargs)
 5: [rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
 5: [rank5]:     result = forward_call(*args, **kwargs)
 5: [rank5]:   File "/workspace/megatron/core/transformer/mlp.py", line 132, in forward
 5: [rank5]:     output, output_bias = self.linear_fc2(intermediate_parallel)
 5: [rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
 5: [rank5]:     return self._call_impl(*args, **kwargs)
 5: [rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1777, in _call_impl
 5: [rank5]:     args_result = hook(self, args)
 5: [rank5]:   File "/workspace/megatron/core/distributed/distributed_data_parallel.py", line 334, in hook
 5: [rank5]:     self.param_to_bucket_group[param].finish_param_sync(
 5: [rank5]:   File "/workspace/megatron/core/distributed/param_and_grad_buffer.py", line 221, in finish_param_sync
 5: [rank5]:     self.next_param_gather_bucket_group.start_param_sync()
 5: [rank5]:   File "/workspace/megatron/core/distributed/param_and_grad_buffer.py", line 167, in start_param_sync
 5: [rank5]:     assert self.param_gather_handle is None
 5: [rank5]: AssertionError

This is a 4-node experiment using the distributed optimizer with overlap param gather and overlap grad reduce both set to True, and tp=2, pp=4.
I don't understand why linear_fc2's next_param_gather_bucket_group already has an in-flight async param_gather handle...?
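
For reference, the relevant DDP settings look roughly like this (a sketch assuming megatron.core's DistributedDataParallelConfig field names; not my exact launch script):

from megatron.core.distributed import DistributedDataParallelConfig

# settings in the failing run (4 nodes, tp=2, pp=4); other fields left at their defaults
ddp_config = DistributedDataParallelConfig(
    use_distributed_optimizer=True,  # optimizer state sharded across DP ranks
    overlap_grad_reduce=True,        # async grad reduce-scatter during backward
    overlap_param_gather=True,       # async param all-gather during forward (the code path that asserts here)
)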

SeunghyunSEO changed the title from "[BUG] 0.9.0 release version got error when 3d parallel is used" to "[BUG] 0.9.0 release version got param_gather_handle error with 3d parallel" on Nov 21, 2024
@SeunghyunSEO
Author

I found it's because the bucket chaining order does not match the forward path.

@deepakn94
Collaborator

Can you share a small reproduction script?

@SeunghyunSEO
Author

SeunghyunSEO commented Nov 22, 2024

Can you share a small reproduction script?

@deepakn94 Hi Deepak, good to see you!
I can't share a simple reproduction script because the error occurs in my custom model.
But I can say that I was experimenting with adding RMSNorm at various points in the residual block, as in this NVIDIA paper, for training stability.
For example, there can be an RMSNorm right after the value layer, the self-attention output, and the fc2 output.

[figure: RMSNorm placements in the residual block, from the NVIDIA paper]

In this scenario, I added the custom layers at the end of TransformerLayer.__init__ (after this line):

class TransformerLayer(MegatronModule, BaseTransformerLayer):

    def __init__(
        self,
        config: TransformerConfig,
        submodules: TransformerLayerSubmodules,
        layer_number: int = 1,
        hidden_dropout: float = None,
    ):
...
        # [Module 8: MLP block]
        # TODO how to set the gpt_layer_spec.py when we have moe_frequency > 1,
        #      where MLP and MoE layer both appear alternately?
        self.mlp = build_module(submodules.mlp, config=self.config)
        if hasattr(self.mlp, 'set_layer_number'):
            self.mlp.set_layer_number(self.layer_number)

        # [Module 9: BiasDropoutFusion]
        self.mlp_bda = build_module(submodules.mlp_bda)

        # @jcasper how should we handle nvfuser?
        # Set bias+dropout+add fusion grad_enable execution handler.
        # TORCH_MAJOR = int(torch.__version__.split('.')[0])
        # TORCH_MINOR = int(torch.__version__.split('.')[1])
        # use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
        self.bias_dropout_add_exec_handler = torch.enable_grad

        # here, the custom layers are added at the end of __init__
        self.attn_out_rmsnorm = ...
        self.fc2_rmsnorm = ...
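
In the custom forward these norms are applied roughly like this (a simplified sketch using assumed/stock Megatron attribute names, not my exact code; bias/tuple handling simplified):

    def forward(self, hidden_states, attention_mask=None):
        ...
        # self_attn_prenorm -> self_attn
        attention_output_with_bias = self.self_attention(
            self.input_layernorm(hidden_states), attention_mask=attention_mask
        )
        # extra norm on the attention output, applied before bias-dropout-add
        out, bias = attention_output_with_bias
        attention_output_with_bias = (self.attn_out_rmsnorm(out), bias)
        ...
        # mlp_prenorm -> fc1 -> fc2
        mlp_output_with_bias = self.mlp(self.pre_mlp_layernorm(hidden_states))
        # extra norm on the fc2 output, applied before bias-dropout-add
        out, bias = mlp_output_with_bias
        mlp_output_with_bias = (self.fc2_rmsnorm(out), bias)
        ...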

But these layers are not executed in the order they are declared in __init__.
The forward order is self_attn_prenorm -> self_attn -> attn_out_rmsnorm -> bda -> mlp_prenorm -> fc1 -> fc2 -> fc2_rmsnorm -> bda,
while the bucketing order follows the declaration order: self_attn_prenorm -> self_attn -> mlp_prenorm -> fc1 -> fc2 -> attn_out_rmsnorm -> fc2_rmsnorm (this line).
So when finish_param_sync fires, for example the 0th layer's first bucket (the 16th bucket) starts its sync (via self.start_param_sync) and also calls next_param_gather_bucket_group.start_param_sync on the 15th bucket.
But because the bucket order does not match the forward path, the attn_out_rmsnorm bucket is the 14th bucket, not the 15th; when forward reaches it, it triggers the next param sync as well, so the 13th bucket is also gathered asynchronously.
The fc1 layer belongs to the 15th bucket; since that bucket's self.param_gather_dispatched is already True, it is fine on its own, but its next_param_gather_bucket_group is the 14th bucket, which has already been all-gathered. However, as you can see below,
there is no check of whether next_param_gather_bucket_group has already been dispatched (param_gather_dispatched):

        # If current bucket's param AG has not been dispatched, dispatch it now (e.g., first
        # AG bucket in first model chunk if ddp_config.align_param_gather is False).
        if not self.param_gather_dispatched:
            self.start_param_sync()

        if self.param_gather_handle is not None:
            self.param_gather_handle.wait()
            self.param_gather_handle = None
            # Dispatch next bucket's asynchronous param AG.
            if self.next_param_gather_bucket_group is not None and not skip_next_bucket_dispatch:
                self.next_param_gather_bucket_group.start_param_sync()

So I fixed the snippet above like this:

        if self.param_gather_handle is not None:
            self.param_gather_handle.wait()
            self.param_gather_handle = None
            # Dispatch next bucket's asynchronous param AG, unless it was already dispatched.
            if (
                self.next_param_gather_bucket_group is not None
                and not skip_next_bucket_dispatch
                and not self.next_param_gather_bucket_group.param_gather_dispatched
            ):
                self.next_param_gather_bucket_group.start_param_sync()

If my explanation is missing anything, please reply or email me, thank you!
(And if you like this solution, I can open a PR.)

@fanzhongyi

fanzhongyi commented Nov 24, 2024

Thank you for providing this patch! I’ve tested it, and it indeed allows training to proceed. However, I’ve observed an issue with checkpointing: after saving a checkpoint, the loss immediately diverges. This suggests that the checkpointing logic is also affected by the mismatch in parameter declaration and usage order.

As a temporary workaround, I’ve adjusted the parameter declaration order to align with the forward-pass usage order (roughly sketched below).
For reference, I’ve attached the loss curve from my experiment. At around step 200, you can see that an asynchronous checkpointing issue caused the loss to spike and diverge.

[figure: training loss curve; the loss spikes and diverges around step 200, after the checkpoint save]

Let me know if there’s a more robust solution in progress or if additional details from my setup would help with debugging.
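
Concretely, the reordering looks roughly like this (a sketch against the custom __init__ shown above, not my exact code): each extra norm is declared next to the module it follows in the forward pass instead of at the end of __init__.

        # extra norm on the attention output, declared right after the attention block so the
        # parameter/bucket order matches the forward order
        self.attn_out_rmsnorm = ...

        ...

        # [Module 8: MLP block]
        self.mlp = build_module(submodules.mlp, config=self.config)

        # extra norm on the fc2 output, declared right after the MLP it follows in forward
        self.fc2_rmsnorm = ...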


@SeunghyunSEO
Author

@fanzhongyi Thank you for testing! To be honest, I hadn't checked checkpoint loading; I'll test that too. Thank you so much :)

@SeunghyunSEO
Author

@deepakn94 May I ask your opinion on this?
