
fsdp::all_gather_copy_in not currently implemented for the XPU device #1328

Open
saforem2 opened this issue Jan 28, 2025 · 5 comments

@saforem2

🚀 The feature, motivation and pitch

FSDP All Gather Copy not Implemented on XPU Device

Overview

I'm trying to run the full_finetune_distributed recipe from pytorch/torchtune, and I receive the following NotImplementedError:

[rank0]: NotImplementedError: The operator 'fsdp::all_gather_copy_in' is not currently implemented for the XPU device. Please open a feature on https://github.com/intel/torch-xpu-ops/issues. You can set the environment variable `PYTORCH_ENABLE_XPU_FALLBACK=1` to use the CPU implementation as a fallback for XPU unimplemented operators. WARNING: this will bring unexpected performance compared with running natively on XPU.

Python and PyTorch Info

pip3 install torch==2.6.0 --index-url https://download.pytorch.org/whl/test/xpu
>>> import torch
>>> torch.xpu.is_available()
True
>>> torch.xpu.device_count()
12
>>> torch.__version__
'2.6.0+xpu'

Full command and output:

#[🐍 anl_2024_12_release_2](👻 anl_2024_12_release_2)
#[09:25:25 AM][x4204c5s6b0n0][/f/A/f/p/p/torchtune][🌱 main][!?][⏱️ 1m8s]
$ tune run full_finetune_distributed --config llama3_1/8B_full optimizer.fused=False
/lus/flare/projects/Aurora_deployment/foremans/micromamba/envs/anl_2024_12_release_2/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/__init__.py:25: UserWarning: Warning: Cannot load xpu CCL. CCL doesnt work for XPU device due to libintel-ext-pt-gpu.so: cannot open shared object file: No such file or directory
  warnings.warn(f"Warning: Cannot load xpu CCL. CCL doesnt work for XPU device due to {e}")
[2025-01-28 09:25:40][I][ezpz/dist:823] Using device='xpu' with backend='DDP' + 'gloo' for distributed training.
[2025-01-28 09:25:40][I][ezpz/dist:869] ['x4204c5s6b0n0'][0/0]
[2025-01-28 09:25:40][I][config/_utils:28:torchtune.utils._logging] Running FullFinetuneRecipeDistributed with resolved config:

batch_size: 2
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: Meta-Llama-3.1-8B-Instruct/
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  model_type: LLAMA3
  output_dir: /tmp/torchtune/llama3_1_8B/full
  recipe_checkpoint: null
clip_grad_norm: null
compile: false
custom_sharded_layers:
- tok_embeddings
- output
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: false
device: xpu
dtype: bf16
enable_activation_checkpointing: false
enable_activation_offloading: false
epochs: 1
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /tmp/torchtune/llama3_1_8B/full/logs
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
optimizer:
  _component_: torch.optim.AdamW
  fused: false
  lr: 2.0e-05
optimizer_in_bwd: false
output_dir: /tmp/torchtune/llama3_1_8B/full
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: /tmp/torchtune/llama3_1_8B/full/profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 3
  with_flops: false
  with_stack: false
resume_from_checkpoint: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: Meta-Llama-3.1-8B-Instruct/original/tokenizer.model

[2025-01-28 09:25:40][I][recipes/full_finetune_distributed:141:__main__] log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False.
/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/torchtune/training/checkpointing/_checkpoint_client.py:75: FutureWarning: get_world_size_and_rank is deprecated and will be removed in future versions. `get_world_size_and_rank` will move to `torchtune.utils._device` in future releases. Please use `torchtune.utils.get_world_size_and_rank` instead.
  _, self._rank = training.get_world_size_and_rank()
[2025-01-28 09:25:40][D][training/seed:60:torchtune.utils._logging] Setting manual seed to local seed 158846877. Local seed is seed + rank = 158846877 + 0
Writing logs to /tmp/torchtune/llama3_1_8B/full/logs/log_1738077940.txt
[2025-01-28 09:26:01][I][recipes/full_finetune_distributed:499:__main__] FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...
/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:863: UserWarning: `_get_pg_default_device` will be deprecated, it only stays for backward-compatiblity reason. If you need to find a device for object collectives, please use `_get_object_coll_device`. If you need to query the device types supported by group, please use `_device_capability(group)`.
  warnings.warn(
/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:907: UserWarning: Multiple backends are registered with this ProcessGroup. We cannot determine which one is the default. Returning cpu. Please consider using other APIs.
  warnings.warn(
[2025-01-28 09:31:12][I][recipes/full_finetune_distributed:568:__main__] Instantiating model and loading checkpoint took 311.13 secs
[2025-01-28 09:31:12][I][training/memory:301:torchtune.utils._logging] Memory stats after model init:
        XPU peak memory allocation: 15.02 GiB
        XPU peak memory reserved: 15.14 GiB
        XPU peak memory active: 15.02 GiB
[2025-01-28 09:31:12][I][recipes/full_finetune_distributed:632:__main__] Optimizer is initialized.
[2025-01-28 09:31:12][I][recipes/full_finetune_distributed:317:__main__] Loss is initialized.
[2025-01-28 09:31:14][I][recipes/full_finetune_distributed:685:__main__] Dataset and Sampler are initialized.
[2025-01-28 09:31:14][I][recipes/full_finetune_distributed:382:__main__] No learning rate scheduler configured. Using constant learning rate.
[2025-01-28 09:31:14][W][training/_profiler:53:torchtune.utils._logging]  Profiling disabled.
[2025-01-28 09:31:14][I][recipes/full_finetune_distributed:467:__main__]  Profiler config after instantiation: {'enabled': False}
  0%|          | 0/26001 [00:00<?, ?it/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/bin/tune", line 8, in <module>
[rank0]:     sys.exit(main())
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/torchtune/_cli/tune.py", line 49, in main
[rank0]:     parser.run(args)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/torchtune/_cli/tune.py", line 43, in run
[rank0]:     args.func(args)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/torchtune/_cli/run.py", line 214, in _run_cmd
[rank0]:     self._run_single_device(args, is_builtin=is_builtin)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/torchtune/_cli/run.py", line 108, in _run_single_device
[rank0]:     runpy.run_path(str(args.recipe), run_name="__main__")
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/micromamba/envs/anl_2024_12_release_2/lib/python3.10/runpy.py", line 289, in run_path
[rank0]:     return _run_module_code(code, init_globals, run_name,
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/micromamba/envs/anl_2024_12_release_2/lib/python3.10/runpy.py", line 96, in _run_module_code
[rank0]:     _run_code(code, mod_globals, init_globals,
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/micromamba/envs/anl_2024_12_release_2/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/recipes/full_finetune_distributed.py", line 928, in <module>
[rank0]:     sys.exit(recipe_main())
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank0]:     sys.exit(recipe_main(conf))
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/recipes/full_finetune_distributed.py", line 923, in recipe_main
[rank0]:     recipe.train()
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/recipes/full_finetune_distributed.py", line 749, in train
[rank0]:     logits = self._model(**batch)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]:     return inner()
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1772, in inner
[rank0]:     args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 71, in fsdp_hook_wrapper
[rank0]:     return torch._dynamo.disable(func, recursive=True)(*args, **kwargs)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 239, in _pre_forward
[rank0]:     args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 334, in pre_forward
[rank0]:     self.unshard(self.unshard_async_op)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 263, in unshard
[rank0]:     self._all_gather_result = foreach_all_gather(
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 153, in foreach_all_gather
[rank0]:     all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/_ops.py", line 1123, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]: NotImplementedError: The operator 'fsdp::all_gather_copy_in' is not currently implemented for the XPU device. Please open a feature on https://github.com/intel/torch-xpu-ops/issues. You can set the environment variable `PYTORCH_ENABLE_XPU_FALLBACK=1` to use the CPU implementation as a fallback for XPU unimplemented operators. WARNING: this will bring unexpected performance compared with running natively on XPU.

Alternatives

No response

Additional context

No response

@saforem2 saforem2 changed the title `fsdp::all_gather_copy_in' not currently implemented for the XPU device fsdp::all_gather_copy_in not currently implemented for the XPU device Jan 28, 2025
@saforem2
Author

Update with PYTORCH_ENABLE_XPU_FALLBACK=1

Trying again with the suggested environment variable PYTORCH_ENABLE_XPU_FALLBACK=1 set, I now receive:

[rank0]: RuntimeError: No backend type associated with device type xpu

Full command and Output

#[🐍 anl_2024_12_release_2](👻 anl_2024_12_release_2)
#[09:35:04 AM][x4204c5s6b0n0][/f/A/f/p/p/torchtune][🌱 main][!?][⏱️ 5m47s]
$ PYTORCH_ENABLE_XPU_FALLBACK=1 tune run full_finetune_distributed --config llama3_1/8B_full optimizer.fused=False
/lus/flare/projects/Aurora_deployment/foremans/micromamba/envs/anl_2024_12_release_2/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/__init__.py:25: UserWarning: Warning: Cannot load xpu CCL. CCL doesnt work for XPU device due to libintel-ext-pt-gpu.so: cannot open shared object file: No such file or directory
  warnings.warn(f"Warning: Cannot load xpu CCL. CCL doesnt work for XPU device due to {e}")
[2025-01-28 09:35:38][I][ezpz/dist:823] Using device='xpu' with backend='DDP' + 'gloo' for distributed training.
[2025-01-28 09:35:38][I][ezpz/dist:869] ['x4204c5s6b0n0'][0/0]
[2025-01-28 09:35:38][I][config/_utils:28:torchtune.utils._logging] Running FullFinetuneRecipeDistributed with resolved config:

batch_size: 2
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: Meta-Llama-3.1-8B-Instruct/
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  model_type: LLAMA3
  output_dir: /tmp/torchtune/llama3_1_8B/full
  recipe_checkpoint: null
clip_grad_norm: null
compile: false
custom_sharded_layers:
- tok_embeddings
- output
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: false
device: xpu
dtype: bf16
enable_activation_checkpointing: false
enable_activation_offloading: false
epochs: 1
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /tmp/torchtune/llama3_1_8B/full/logs
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
optimizer:
  _component_: torch.optim.AdamW
  fused: false
  lr: 2.0e-05
optimizer_in_bwd: false
output_dir: /tmp/torchtune/llama3_1_8B/full
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: /tmp/torchtune/llama3_1_8B/full/profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 3
  with_flops: false
  with_stack: false
resume_from_checkpoint: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: Meta-Llama-3.1-8B-Instruct/original/tokenizer.model

[2025-01-28 09:35:38][I][recipes/full_finetune_distributed:141:__main__] log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False.
/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/torchtune/training/checkpointing/_checkpoint_client.py:75: FutureWarning: get_world_size_and_rank is deprecated and will be removed in future versions. `get_world_size_and_rank` will move to `torchtune.utils._device` in future releases. Please use `torchtune.utils.get_world_size_and_rank` instead.
  _, self._rank = training.get_world_size_and_rank()
[2025-01-28 09:35:38][D][training/seed:60:torchtune.utils._logging] Setting manual seed to local seed 3173661602. Local seed is seed + rank = 3173661602 + 0
Writing logs to /tmp/torchtune/llama3_1_8B/full/logs/log_1738078538.txt
[2025-01-28 09:35:55][I][recipes/full_finetune_distributed:499:__main__] FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...
/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:863: UserWarning: `_get_pg_default_device` will be deprecated, it only stays for backward-compatiblity reason. If you need to find a device for object collectives, please use `_get_object_coll_device`. If you need to query the device types supported by group, please use `_device_capability(group)`.
  warnings.warn(
/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:907: UserWarning: Multiple backends are registered with this ProcessGroup. We cannot determine which one is the default. Returning cpu. Please consider using other APIs.
  warnings.warn(
[2025-01-28 09:40:18][I][recipes/full_finetune_distributed:568:__main__] Instantiating model and loading checkpoint took 263.44 secs
[2025-01-28 09:40:18][I][training/memory:301:torchtune.utils._logging] Memory stats after model init:
        XPU peak memory allocation: 15.02 GiB
        XPU peak memory reserved: 15.14 GiB
        XPU peak memory active: 15.02 GiB
[2025-01-28 09:40:18][I][recipes/full_finetune_distributed:632:__main__] Optimizer is initialized.
[2025-01-28 09:40:18][I][recipes/full_finetune_distributed:317:__main__] Loss is initialized.
[2025-01-28 09:40:20][I][recipes/full_finetune_distributed:685:__main__] Dataset and Sampler are initialized.
[2025-01-28 09:40:20][I][recipes/full_finetune_distributed:382:__main__] No learning rate scheduler configured. Using constant learning rate.
[2025-01-28 09:40:20][W][training/_profiler:53:torchtune.utils._logging]  Profiling disabled.
[2025-01-28 09:40:20][I][recipes/full_finetune_distributed:467:__main__]  Profiler config after instantiation: {'enabled': False}
  0%|          | 0/26001 [00:00<?, ?it/s]
[rank0]:[W128 09:40:21.158240118 RegisterXPU.cpp:45778] Warning: Aten Op fallback from XPU to CPU happends. This may have performance implications. If need debug the fallback ops please set environment variable `PYTORCH_DEBUG_XPU_FALLBACK=1`  (function operator())
[rank0]: Traceback (most recent call last):
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/bin/tune", line 8, in <module>
[rank0]:     sys.exit(main())
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/torchtune/_cli/tune.py", line 49, in main
[rank0]:     parser.run(args)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/torchtune/_cli/tune.py", line 43, in run
[rank0]:     args.func(args)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/torchtune/_cli/run.py", line 214, in _run_cmd
[rank0]:     self._run_single_device(args, is_builtin=is_builtin)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/torchtune/_cli/run.py", line 108, in _run_single_device
[rank0]:     runpy.run_path(str(args.recipe), run_name="__main__")
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/micromamba/envs/anl_2024_12_release_2/lib/python3.10/runpy.py", line 289, in run_path
[rank0]:     return _run_module_code(code, init_globals, run_name,
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/micromamba/envs/anl_2024_12_release_2/lib/python3.10/runpy.py", line 96, in _run_module_code
[rank0]:     _run_code(code, mod_globals, init_globals,
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/micromamba/envs/anl_2024_12_release_2/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/recipes/full_finetune_distributed.py", line 928, in <module>
[rank0]:     sys.exit(recipe_main())
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank0]:     sys.exit(recipe_main(conf))
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/recipes/full_finetune_distributed.py", line 923, in recipe_main
[rank0]:     recipe.train()
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/recipes/full_finetune_distributed.py", line 749, in train
[rank0]:     logits = self._model(**batch)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]:     return inner()
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1772, in inner
[rank0]:     args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 71, in fsdp_hook_wrapper
[rank0]:     return torch._dynamo.disable(func, recursive=True)(*args, **kwargs)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 239, in _pre_forward
[rank0]:     args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 334, in pre_forward
[rank0]:     self.unshard(self.unshard_async_op)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 263, in unshard
[rank0]:     self._all_gather_result = foreach_all_gather(
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 165, in foreach_all_gather
[rank0]:     all_gather_work = dist.all_gather_into_tensor(
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/pytorch/torchtune/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3798, in all_gather_into_tensor
[rank0]:     work = group._allgather_base(output_tensor, input_tensor, opts)
[rank0]: RuntimeError: No backend type associated with device type xpu
  0%|          | 0/26001 [00:00<?, ?it/s]
[1]    143628 exit 1     PYTORCH_ENABLE_XPU_FALLBACK=1 tune run full_finetune_distributed --config
took: 0h:05m:13s
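A likely reading of this second error, inferred from the log line above that reports `backend='DDP' + 'gloo'`: the default process group only has a CPU (gloo) backend, so when FSDP calls `all_gather_into_tensor` on XPU tensors there is no backend registered for the `xpu` device type. The sketch below builds a per-device backend string. It assumes PyTorch's comma-separated `"device:backend"` syntax (as in `"cpu:gloo,cuda:nccl"`) and that `"xccl"` is the XPU backend name in a build with XCCL enabled, as the later comments in this thread describe. `backend_string` and `DEFAULT_BACKENDS` are hypothetical helpers, not part of torchtune or ezpz.

```python
# Hypothetical helper: build a per-device backend string for
# torch.distributed.init_process_group(backend=...).
# Assumptions: PyTorch accepts comma-separated "device:backend" pairs
# (e.g. "cpu:gloo,cuda:nccl"), and "xccl" is the XPU backend name in
# builds compiled with XCCL support.

DEFAULT_BACKENDS = {
    "cpu": "gloo",
    "cuda": "nccl",
    "xpu": "xccl",  # assumption: only usable if XCCL was built in
}


def backend_string(device_types, overrides=None):
    """Return a string such as 'cpu:gloo,xpu:xccl' for the given device types."""
    mapping = {**DEFAULT_BACKENDS, **(overrides or {})}
    pairs = []
    # dict.fromkeys drops duplicates while keeping first-seen order
    for dev in dict.fromkeys(device_types):
        if dev not in mapping:
            raise ValueError(f"no backend known for device type {dev!r}")
        pairs.append(f"{dev}:{mapping[dev]}")
    return ",".join(pairs)


# Intended use (not executed here):
#   import torch.distributed as dist
#   dist.init_process_group(backend=backend_string(["cpu", "xpu"]))
```

Keeping a CPU entry alongside `xpu` lets CPU-side object collectives keep working while XPU tensors go through XCCL; whether that split is needed depends on the training code.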

@saforem2
Author

Okay, I was able to get past the RuntimeError from my previous attempt above.

Instead, now I see the following warning:

/lus/flare/projects/Aurora_deployment/foremans/micromamba/envs/anl_2024_12_release_2/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/__init__.py:25: UserWarning: Warning: Cannot load xpu CCL. CCL doesn't work for XPU device due to libintel-ext-pt-gpu.so: cannot open shared object file: No such file or directory
  warnings.warn(f"Warning: Cannot load xpu CCL. CCL doesn't work for XPU device due to {e}")

Then it finally crashes with:

[2025-01-28 10:06:31][I][ezpz/dist:831] Using device='xpu' with backend='DDP' + 'ccl' for distributed training.
[2025-01-28 10:06:31][I][ezpz/dist:877] ['x4204c5s3b0n0'][ 0/47]
[2025-01-28 10:06:31][I][ezpz/test_dist:369:__main__] model=
Network(
  (layers): Sequential(
    (0): Linear(in_features=128, out_features=1024, bias=True)
    (1): Linear(in_features=1024, out_features=512, bias=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): Linear(in_features=128, out_features=128, bias=True)
  )
)
[rank38]: Traceback (most recent call last):
[rank38]:   File "/lus/flare/projects/Aurora_deployment/foremans/micromamba/envs/anl_2024_12_release_2/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank38]:     return _run_code(code, main_globals, None,
[rank38]:   File "/lus/flare/projects/Aurora_deployment/foremans/micromamba/envs/anl_2024_12_release_2/lib/python3.10/runpy.py", line 86, in _run_code
[rank38]:     exec(code, run_globals)
[rank38]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/saforem2/mmm/deps/ezpz/src/ezpz/test_dist.py", line 420, in <module>
[rank38]:     trainer = main()
[rank38]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/saforem2/mmm/deps/ezpz/src/ezpz/test_dist.py", line 405, in main
[rank38]:     trainer = train(config)
[rank38]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/saforem2/mmm/deps/ezpz/src/ezpz/test_dist.py", line 214, in train
[rank38]:     model, optimizer = build_model_and_optimizer(model, backend=config.backend)
[rank38]:   File "/lus/flare/projects/Aurora_deployment/foremans/projects/saforem2/mmm/deps/ezpz/src/ezpz/test_dist.py", line 373, in build_model_and_optimizer
[rank38]:     model = DDP(model, device_ids=[ezpz.get_local_rank()])
[rank38]:   File "/flare/Aurora_deployment/foremans/projects/saforem2/mmm/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 825, in __init__
[rank38]:     _verify_param_shape_across_processes(self.process_group, parameters)
[rank38]:   File "/flare/Aurora_deployment/foremans/projects/saforem2/mmm/venvs/anl_2024_12_release_2/lib/python3.10/site-packages/torch/distributed/utils.py", line 294, in _verify_param_shape_across_processes
[rank38]:     return dist._verify_params_across_processes(process_group, tensors, logger)
[rank38]: RuntimeError: oneccl_bindings_for_pytorch: allgather isn't implementd on backend [xpu].
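The warning earlier in this comment says CCL does not work for the XPU device because `libintel-ext-pt-gpu.so` cannot be opened, which is consistent with the final `allgather isn't implementd on backend [xpu]` error. A quick, hedged way to check whether the dynamic loader can find that library from the current environment is to try opening it with `ctypes`. `can_load` is a hypothetical helper; oneccl_bindings_for_pytorch may locate the library differently (for example relative to an installed package), so a failure here is a hint rather than proof.

```python
import ctypes


def can_load(libname):
    """Try to dlopen `libname`; return (ok, message)."""
    try:
        ctypes.CDLL(libname)
    except OSError as exc:  # raised when the loader cannot find or open the file
        return False, str(exc)
    return True, "loaded"


if __name__ == "__main__":
    # The library named in the oneccl_bindings_for_pytorch warning.
    ok, msg = can_load("libintel-ext-pt-gpu.so")
    print("libintel-ext-pt-gpu.so:", "OK" if ok else f"not loadable ({msg})")
```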

@dvrogozh
Contributor

@saforem2: I think that PYTORCH_ENABLE_XPU_FALLBACK=1 won't help you if you plan to use distributed XCCL to run in a multi-XPU environment. In that case you really need the ops implemented for XPU rather than CPU substitutes. At the moment, the latest PyTorch (pytorch/pytorch@6371c25) has enabled XCCL, but a number of ops are still missing. Some of them are already implemented at the torch-xpu-ops level, but the commit pin on the PyTorch side has not been updated yet. If you want to try them out, you can update the commit pin manually:

--- a/third_party/xpu.txt
+++ b/third_party/xpu.txt
@@ -1 +1 @@
-22cc419e4e60f469341712a5a103fa309a7dfd48
+main
  • Make sure to source oneAPI environment variables for your build and runtime environments. I am doing this on my side (note the last line):
. /opt/intel/oneapi/compiler/2025.0/env/vars.sh
. /opt/intel/oneapi/umf/0.9/env/vars.sh
. /opt/intel/oneapi/pti/0.10/env/vars.sh
. /opt/intel/oneapi/ccl/2021.14/env/vars.sh
  • To build PyTorch with XCCL against torch-xpu-ops@main, explicitly set USE_XCCL=1 (note a6f4c32):
USE_XCCL=1 python3 setup.py develop
  • When building, check that XCCL is actually getting built:
-- Found XCCL: /opt/intel/oneapi/ccl/2021.14/include;/opt/intel/oneapi/ccl/2021.14/include/oneapi
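If the build output is long, the `Found XCCL` line can be checked mechanically. A minimal sketch, assuming the build output was captured to a file (for example with `USE_XCCL=1 python3 setup.py develop 2>&1 | tee build.log`) and that CMake prints the line in the format shown above; `find_xccl` is a hypothetical helper.

```python
import re

# Matches CMake configure output such as
# "-- Found XCCL: /opt/intel/oneapi/ccl/2021.14/include;/opt/intel/oneapi/ccl/2021.14/include/oneapi"
_XCCL_RE = re.compile(r"^--[ \t]+Found XCCL:[ \t]*(?P<paths>\S.*)$", re.MULTILINE)


def find_xccl(build_log):
    """Return the XCCL include paths reported by CMake, or [] if the line is absent."""
    match = _XCCL_RE.search(build_log)
    if match is None:
        return []
    # CMake lists are ';'-separated
    return [p for p in match.group("paths").strip().split(";") if p]


# Intended use (not executed here):
#   with open("build.log") as f:
#       print(find_xccl(f.read()))
```

An empty list means the configure step did not report XCCL, which suggests the resulting build will not include the XCCL backend.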

To check that XCCL is available at runtime:

$ python -c "import torch; print(torch.distributed.distributed_c10d.is_xccl_available())"
True
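The one-liner above can be extended into a slightly broader sanity check that also reports XPU availability and does not fail on CPU-only or older builds that lack `is_xccl_available`. A sketch under those assumptions; `xpu_distributed_report` is a hypothetical helper, and it only relies on `torch.xpu.is_available`, `torch.xpu.device_count`, `torch.distributed.is_available`, and the `distributed_c10d.is_xccl_available` function shown above.

```python
def xpu_distributed_report():
    """Summarize XPU and XCCL availability; tolerant of CPU-only or older builds."""
    try:
        import torch
    except ImportError:
        return {"torch": None}

    report = {"torch": str(torch.__version__)}

    # torch.xpu is missing on older releases, so look it up defensively.
    xpu = getattr(torch, "xpu", None)
    report["xpu_available"] = bool(xpu is not None and xpu.is_available())
    report["xpu_device_count"] = xpu.device_count() if report["xpu_available"] else 0

    # is_xccl_available only exists in builds recent enough to know about XCCL.
    check = None
    if torch.distributed.is_available():
        check = getattr(torch.distributed.distributed_c10d, "is_xccl_available", None)
    report["xccl_available"] = bool(check()) if check is not None else False
    return report


if __name__ == "__main__":
    print(xpu_distributed_report())
```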

@saforem2
Author

Oh, awesome, thank you for this! I will test with your changes and report back.

@daisyden daisyden added this to the PT2.7 milestone Feb 20, 2025
@zhangxiaoli73

@saforem2 We have implemented all collectives in the XCCL backend; please try the latest stock PyTorch and torch-xpu-ops (building with USE_XCCL=ON). Let me know if you still have any problems.


6 participants