PGLE doesn't work for Tensor Parallelism #1005

Open

wang2yn84 opened this issue Nov 1, 2024 · 3 comments

Comments

@wang2yn84 (Collaborator) commented Nov 1, 2024

We observed good overlap between computation and communication with FSDP + PGLE:

[screenshot: FSDP + PGLE profile showing overlapped compute and collectives]

Turning PGLE on and off makes a big difference here.

However, with TP + PGLE:

[screenshot: TP + PGLE profile]

there is no performance improvement; computation and communication are completely exposed.

Here is the command (after switching to the lance-405b-clean branch):

python3 MaxText/train.py MaxText/configs/models/gpu/llama3.1_405b.yml hardware=gpu \
  run_name=maxtext-llama3.1-405b steps=10 max_target_length=4096 model_name=llama3.1-405b \
  enable_checkpointing=false attention=cudnn_flash_te dataset_type=synthetic \
  async_checkpointing=false base_output_directory=gs://lancewang-dev-supercomputer-testing/maxtext_gpu \
  logits_dot_in_fp32=false use_iota_embed=true ici_tensor_parallelism=8 dcn_fsdp_parallelism=32 \
  dcn_pipeline_parallelism=1 per_device_batch_size=1 num_layers_per_pipeline_stage=16 weight_dtype=bfloat16 \
  remat_policy=save_qkv_proj profiler=xplane skip_first_n_steps_for_profiler=5 \
  base_num_decoder_layers=126
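
For orientation, a quick sketch of the device layout these parallelism arguments imply (the 8-GPUs-per-node figure is an assumption on my part):

# Parallelism layout implied by the command (assumption: 8 GPUs per node)
# ici_tensor_parallelism=8    -> tensor parallelism across the 8 GPUs within a node
# dcn_fsdp_parallelism=32     -> FSDP sharding across 32 nodes over DCN
# dcn_pipeline_parallelism=1  -> no pipeline parallelism across nodes
# total devices = 8 * 32 * 1 = 256 GPUs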

Here are the XLA flags:
--xla_gpu_graph_level=0 --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
--xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_all_reduce_combine_threshold_bytes=536870912
--xla_gpu_all_gather_combine_threshold_bytes=536870912 --xla_gpu_reduce_scatter_combine_threshold_bytes=536870912
--xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true
--xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true
--xla_disable_hlo_passes=rematerialization --xla_gpu_enable_pgle_accuracy_checker=false
--xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false
--xla_gpu_enable_reduce_scatter_combine_by_dim=false

Here are the environment variables:
NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE=/usr/local/nvidia/lib64/a3plus_guest_config.textproto
NCCL_FASTRAK_PLUGIN_ACCEPT_TIMEOUT_MS=600000
JAX_ENABLE_PGLE=true
JAX_REMOVE_CUSTOM_PARTITIONING_PTR_FROM_CACHE_KEY=true
JAX_DEBUG_LOG_MODULES=compiler
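
For completeness, here is how the pieces above fit together at launch time. This is a minimal sketch, assuming a bash launch script on each worker; the env var values, XLA flags, and train.py arguments are copied verbatim from above, while the script structure and comments are mine.

#!/bin/bash
# NCCL settings (values copied from above)
export NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE=/usr/local/nvidia/lib64/a3plus_guest_config.textproto
export NCCL_FASTRAK_PLUGIN_ACCEPT_TIMEOUT_MS=600000

# Automatic PGLE: JAX profiles the first steps, feeds the measured compute and
# collective times back into XLA's latency-hiding scheduler, and recompiles.
export JAX_ENABLE_PGLE=true
export JAX_REMOVE_CUSTOM_PARTITIONING_PTR_FROM_CACHE_KEY=true
export JAX_DEBUG_LOG_MODULES=compiler

# XLA flags copied verbatim from the list above; the PGLE profile is consumed
# by the latency-hiding scheduler, so that flag stays enabled.
export XLA_FLAGS="--xla_gpu_graph_level=0 \
  --xla_gpu_enable_latency_hiding_scheduler=true \
  --xla_gpu_enable_triton_gemm=false \
  --xla_gpu_enable_highest_priority_async_stream=true \
  --xla_gpu_all_reduce_combine_threshold_bytes=536870912 \
  --xla_gpu_all_gather_combine_threshold_bytes=536870912 \
  --xla_gpu_reduce_scatter_combine_threshold_bytes=536870912 \
  --xla_gpu_enable_pipelined_all_gather=true \
  --xla_gpu_enable_pipelined_reduce_scatter=true \
  --xla_gpu_enable_pipelined_all_reduce=true \
  --xla_gpu_enable_while_loop_double_buffering=true \
  --xla_disable_hlo_passes=rematerialization \
  --xla_gpu_enable_pgle_accuracy_checker=false \
  --xla_gpu_enable_triton_softmax_fusion=false \
  --xla_gpu_enable_all_gather_combine_by_dim=false \
  --xla_gpu_enable_reduce_scatter_combine_by_dim=false"

python3 MaxText/train.py MaxText/configs/models/gpu/llama3.1_405b.yml hardware=gpu \
  run_name=maxtext-llama3.1-405b steps=10 max_target_length=4096 model_name=llama3.1-405b \
  enable_checkpointing=false attention=cudnn_flash_te dataset_type=synthetic \
  async_checkpointing=false base_output_directory=gs://lancewang-dev-supercomputer-testing/maxtext_gpu \
  logits_dot_in_fp32=false use_iota_embed=true ici_tensor_parallelism=8 dcn_fsdp_parallelism=32 \
  dcn_pipeline_parallelism=1 per_device_batch_size=1 num_layers_per_pipeline_stage=16 weight_dtype=bfloat16 \
  remat_policy=save_qkv_proj profiler=xplane skip_first_n_steps_for_profiler=5 \
  base_num_decoder_layers=126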

We built the image on Oct 22nd.

@reedwm (Contributor) commented Nov 1, 2024

@Tixxx do you know what the issue is? I'm still trying to reproduce it myself.

@Tixxx commented Nov 7, 2024

I cannot access the screenshots above; they say "page not found." As a preliminary guess, the large combine thresholds might introduce more data dependencies, so we usually tune them down when the collectives are combined ones with many data dependencies.

I have tried reproducing with your command on MaxText main, but the YAML file doesn't exist for me. Would you be able to share a smaller model config that can easily be reproduced on a single node? Thanks.
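
For concreteness, that suggestion corresponds to lowering the three 512 MiB (536870912-byte) combine thresholds in the XLA flags above, for example to 128 MiB; the value below is illustrative only and would need tuning against a fresh profile:

# Illustrative only: lower the combiner thresholds from 512 MiB to e.g. 128 MiB
# and re-profile; the right value depends on the model and interconnect.
--xla_gpu_all_reduce_combine_threshold_bytes=134217728
--xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728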

@wang2yn84 (Collaborator, Author) commented Nov 7, 2024 via email
