
Why the magatron_v4.patch is needed? #14

Open
hxdtest opened this issue Nov 18, 2024 · 4 comments

Comments

@hxdtest

hxdtest commented Nov 18, 2024

https://github.com/volcengine/verl/blob/main/patches/megatron_v4.patch

For example:

  • case 1
-    tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
+    tensor_shape = [seq_length, micro_batch_size, hidden_size]

What is the difference between hidden_size and config.hidden_size?

  • case 2
     # Run 1F1B in steady state.
     for i in range(num_microbatches_remaining):
         last_iteration = i == (num_microbatches_remaining - 1)
+        next_forward_k = num_warmup_microbatches + i + 1
+        backward_k = i
 

Why do you need next_forward_k and backward_k?

  • case 3
-        return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps)
+        return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps, False)

Why is False needed?
And in current apex, memory_efficient seems to be set to False by default (fused_layer_norm.py); a quick check is sketched after this list.

  • case 4
+        self.overlap_param_gather = overlap_param_gather
         if self.overlap_param_gather:
             self.remove_pre_hook_handle = torch.nn.modules.module.register_module_forward_pre_hook(
                 self._make_forward_pre_hook())

Why do you need overlap_param_gather? Does it have side effects on training?
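
For reference on case 3, here is a quick sanity check I would run against the installed apex to see whether memory_efficient exists and what it defaults to. It is illustrative only and assumes apex's fused layer norm module imports cleanly; it is not part of the patch:

    # Sanity check (illustrative): print apex's fused layer norm forward signature
    # to confirm whether the memory_efficient parameter exists and its default.
    import inspect
    from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction

    print(inspect.signature(FusedLayerNormAffineFunction.forward))
    # On recent apex versions this is expected to look roughly like:
    # (ctx, input, weight, bias, normalized_shape, eps, memory_efficient=False)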

Many thanks!

@hxdtest changed the title from "Why the magatron_v4.patch is needed" to "Why the magatron_v4.patch is needed?" on Nov 18, 2024
@PeterSH6
Collaborator

Hi @hxdtest, the megatron_v4.patch is necessary for veRL for two main reasons:

  1. In veRL, we don't initialize Megatron-LM with initialize_megatron, which is what sets up the global args. We only build the necessary process groups with mpu.initialize_model_parallel, so we have to remove the usages of get_args(). Case 4 is one of the places where get_args() is removed, and overlap_param_gather is set to False by default (a minimal initialization sketch appears at the end of this comment).
  2. We fix the vpp (virtual pipeline parallelism) hanging problem that occurs when applying remove-padding techniques in model training. Case 2 is part of that fix; a sketch of the steady-state index bookkeeping follows below.
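
To make the index bookkeeping in case 2 concrete, here is a minimal, illustrative sketch (not the patch code itself) of how microbatch indices line up in the 1F1B steady state, presumably so that per-microbatch information (e.g. sequence lengths that differ after padding removal) can be looked up on both the forward and the backward side:

    # Illustrative sketch: which microbatch runs forward vs. backward at each
    # steady-state step of a 1F1B schedule.
    def steady_state_schedule(num_microbatches: int, num_warmup_microbatches: int):
        num_microbatches_remaining = num_microbatches - num_warmup_microbatches
        schedule = []
        for i in range(num_microbatches_remaining):
            next_forward_k = num_warmup_microbatches + i + 1  # microbatch fed forward next
            backward_k = i                                    # microbatch whose gradients flow back
            schedule.append((next_forward_k, backward_k))
        return schedule

    print(steady_state_schedule(num_microbatches=8, num_warmup_microbatches=3))
    # [(4, 0), (5, 1), (6, 2), (7, 3), (8, 4)]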

For case 1, config.hidden_size should be equal to hidden_size.
The False in case 3 could be removed, as the default value is already False and there seems to be no way to change it in v0.4.
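
Regarding point 1, a minimal sketch of the kind of initialization veRL relies on instead of initialize_megatron (illustrative only; the parallel sizes and backend here are placeholders):

    # Illustrative sketch: build only the Megatron process groups, without
    # initialize_megatron(), so no global args exist and get_args() cannot be used.
    import torch
    import megatron.core.parallel_state as mpu

    def init_model_parallel(tp_size: int = 2, pp_size: int = 1) -> None:
        # torch.distributed must already be initialized (e.g. launched via torchrun).
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
        mpu.initialize_model_parallel(
            tensor_model_parallel_size=tp_size,
            pipeline_model_parallel_size=pp_size,
        )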

@hxdtest
Author

hxdtest commented Nov 20, 2024

Many thanks for your reply.

@hxdtest
Author

hxdtest commented Nov 20, 2024

@PeterSH6
Have you tested verl with model sizes larger than 300B? For example, have you tested Llama 3 405B PPO training on verl?

@PeterSH6
Collaborator

@hxdtest, we haven't tested verl on the 405B model.

I think we can try it by using a larger TP size in rollout or implementing pipeline parallelism in vLLM rollout. This is one of our plans.
