[QUESTION] Training Mixtral 8x7B on 16x H100 only achieves a low throughput of 130 TFLOPS #1178
Replies: 28 comments
-
Thank you for reporting this issue. 130 TFLOPS is indeed too low for the H100.
-
Hi, thanks for the suggestions.
The throughput has indeed increased significantly, reaching around 230 TFLOP/s. Here are the logs:
-
If expert_parallel_size == num_moe_experts, then num_local_experts is 1 and GroupedMLP is equivalent to SequentialMLP, is that right? Also, as far as I know, the communication overhead of PP is lower than that of TP and EP as long as the bubble-time fraction is not too high. Does MoE support PP, and would it make training more efficient?
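A minimal sketch of the relation being asked about, assuming Megatron-Core's usual even split of experts across the expert-parallel group (variable names mirror the question, not library internals):

```python
# Minimal sketch: how many experts live on each EP rank, assuming experts are
# split evenly across the expert-parallel group.
num_moe_experts = 8          # Mixtral 8x7B has 8 experts per MoE layer
expert_parallel_size = 8     # EP degree

assert num_moe_experts % expert_parallel_size == 0
num_local_experts = num_moe_experts // expert_parallel_size  # -> 1

# With a single local expert there is nothing to group, so GroupedMLP and
# SequentialMLP end up issuing the same one GEMM per expert on each rank.
```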
-
Apologies for the delayed reply. 230 TFLOPS is below our expectations; we can currently exceed 330 TFLOPS on the H100, and potentially go higher by switching to EP8 TP1 with recomputation.
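For illustration, a hedged sketch of what an EP8 TP1 plus recomputation configuration might look like; the flag names below follow recent Megatron-LM releases and may differ in other versions, so treat this as an assumption rather than the exact setup used here:

```python
# Hypothetical flag set for an EP8 TP1 run with recomputation (verify the names
# against the arguments.py of your Megatron-LM version before using them).
ep8_tp1_flags = [
    "--num-experts", "8",
    "--expert-model-parallel-size", "8",   # EP8
    "--tensor-model-parallel-size", "1",   # TP1
    "--moe-grouped-gemm",                  # batch the per-expert GEMMs
    "--recompute-granularity", "full",     # recomputation trades FLOPs for memory headroom
    "--recompute-method", "uniform",
    "--recompute-num-layers", "1",
]
```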
-
Does that mean you can achieve over 330 TFLOPS in the same or similar software environment and settings?
-
Hi @ShinoharaHare, our env is:
I double-checked your scripts and suggest the following modifications:
Let's see how performance changes with these modifications ^ ^.
-
Hi XLZed, MCore MoE does support PP, but for the Mixtral 8x7B model, we prefer EP and TP.
-
Does grouped_gemm support variable token counts for the local experts on the same rank?
-
Yes, we support variable-length inputs for each local expert.
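A reference sketch in plain PyTorch (not the grouped_gemm kernel itself) of what variable-length inputs per local expert mean; all names and sizes here are illustrative:

```python
import torch

hidden, ffn = 4096, 14336
tokens_per_expert = torch.tensor([5, 0, 12, 3])         # variable counts, zero is allowed
x = torch.randn(int(tokens_per_expert.sum()), hidden)   # tokens already sorted by expert
w = torch.randn(len(tokens_per_expert), hidden, ffn)    # one weight matrix per local expert

# The grouped GEMM is numerically equivalent to one GEMM per expert on its own slice.
outputs = [chunk @ w_e for chunk, w_e in zip(torch.split(x, tokens_per_expert.tolist()), w)]
y = torch.cat(outputs)                                  # [sum(tokens_per_expert), ffn]
```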
-
@yanring |
-
Which modification brings the most speed improvement?
-
@ShinoharaHare Could you please share your checkpoint conversion script? |
-
The most significant performance improvement comes from resuming from a trained checkpoint. If you do not have pretrained weights, you can train from scratch for about 500 steps; we noticed that after several hundred steps the token distribution becomes quite balanced.
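A hypothetical monitoring sketch for the balance point mentioned above; `router_logits` is a stand-in for whatever routing output your model exposes, not a Megatron-LM API:

```python
import torch

def expert_load(router_logits: torch.Tensor, num_experts: int, top_k: int = 2) -> torch.Tensor:
    """Fraction of routed tokens each expert receives (uniform ~= 1 / num_experts)."""
    top_experts = router_logits.topk(top_k, dim=-1).indices           # [num_tokens, top_k]
    counts = torch.bincount(top_experts.flatten(), minlength=num_experts)
    return counts / counts.sum()

# A near-uniform distribution corresponds to the balanced regime after a few hundred
# steps, where per-expert GEMM shapes stop being heavily skewed.
```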
-
@yanring, @ShinoharaHare, can you please share a conversion script for Mixtral from HF weights?
-
Hi Vlad, we are working on the converter; it is already in the review process.
-
Hi, I'm in a similar situation to this issue, with some differences: we use 8x H800, 64 experts, EP=8, TP=1, PP=1. I also ran into some training efficiency issues, but they are not my top priority. What bothers me now is that after enabling EP8 and grouped GEMM, my model structure changed and the inference results are incorrect. Will Megatron-LM provide a model conversion tool that can merge the EP=8 model into an EP=1 model? Or could you provide some information on how to merge a grouped-GEMM-enabled model?
-
Hello @hwdef, thank you for the update. Currently, the weight format in GroupedGEMM for each expert is [input_size, output_size], which differs from the [output_size, input_size] format used in SequentialMLP's ParallelLinear. Did you transpose the weights during your conversion? @cb521 can take a look if the issue persists. By the way, we are also working on supporting distributed checkpointing with GroupedGEMM.
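A hedged sketch of the transpose described above; the variable names and sizes are illustrative, not actual Megatron-LM state-dict keys:

```python
import torch

hidden, ffn = 4096, 14336
sequential_weight = torch.randn(ffn, hidden)         # SequentialMLP / ParallelLinear: [output_size, input_size]
grouped_weight = sequential_weight.t().contiguous()  # GroupedGEMM layout: [input_size, output_size]

# When packing several local experts into one GroupedMLP buffer, transpose each
# expert's weight like this before stacking or concatenating them.
```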
-
Yes, we have accounted for the order of output_size and input_size.
-
@yanring |
-
I’m excited about this. When do you plan to merge it into the main branch?
-
@hwdef Hello, I've run into the same problem. Is there a solution for it yet?
-
Not yet.
-
Hi @yanring, may I ask why you do not prefer PP? I think the PP-only setting should be the fastest one for multi-node training.
-
Yeah, in multi-node training, using PP outperforms EP. But within a node, EP is still the go-to. For instance, the best setup for Mixtral 8x7B on 64x H100 is TP1 EP8 PP8, which hits over 400 TFLOPS. However, if there are only 16x H100, using PP2 might lead to an OOM error, so you'd have to use TP2 instead.
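A back-of-the-envelope sketch of the two setups above, assuming (as in Megatron-Core) that the expert-parallel group is carved out of the data-parallel group, so EP must divide world_size / (TP * PP):

```python
def data_parallel_size(world_size: int, tp: int, pp: int) -> int:
    """Ranks left for data (and thus expert) parallelism after TP and PP are fixed."""
    assert world_size % (tp * pp) == 0
    return world_size // (tp * pp)

print(data_parallel_size(64, tp=1, pp=8))  # 8 -> EP8 fits (TP1 EP8 PP8 on 64x H100)
print(data_parallel_size(16, tp=2, pp=1))  # 8 -> EP8 fits (TP2 EP8 on 16x H100)
```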
-
Sorry for the late response. We have added the converter, along with its documentation, at https://github.com/NVIDIA/Megatron-LM/tree/main/examples/mixtral
-
Hi @yanring, I have a question regarding the current grouped GEMM experts implementation. By "checkpointing", did you mean saving model checkpoints or gradient checkpointing?
-
Pasting a slide on distributed checkpointing (distckpt). We've also added distckpt support for GroupedMLP, enabling you to load a SequentialMLP checkpoint with GroupedMLP enabled.
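A hedged sketch of the flags involved; `--use-dist-ckpt` is the distributed-checkpoint switch in recent Megatron-LM releases, but the exact name may vary by version, so verify against your arguments.py:

```python
# Hypothetical flag set: load a SequentialMLP dist-ckpt while running with GroupedMLP.
distckpt_flags = [
    "--use-dist-ckpt",        # save/load checkpoints in the distributed (dist-ckpt) format
    "--moe-grouped-gemm",     # run GroupedMLP; dist-ckpt resolves the layout difference on load
]
```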
-
Marking as stale. No activity in 60 days.
-
As the title says, I wonder if this is normal.
If not, how should I optimize it?
Logs