I am trying to use NF4Tensor weights in my model and wrap it with DistributedDataParallel, but I get the following error:
[rank0]: model = DistributedDataParallel(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/path/to/venv/lib/python3.12/site-packages/torch/nn/parallel/distributed.py", line 837, in __init__
[rank0]: _sync_module_states(
[rank0]: File "/path/to/venv/lib/python3.12/site-packages/torch/distributed/utils.py", line 313, in _sync_module_states
[rank0]: _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src)
[rank0]: File "/path/to/venv/lib/python3.12/site-packages/torch/distributed/utils.py", line 324, in _sync_params_and_buffers
[rank0]: dist._broadcast_coalesced(
[rank0]: File "/path/to/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/path/to/venv/lib/python3.12/site-packages/torchao/dtypes/nf4tensor.py", line 834, in __torch_dispatch__
[rank0]: raise NotImplementedError(
[rank0]: NotImplementedError: NF4Tensor dispatch: attempting to run aten.cat.default, this is not supported
To replicate:
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4
from torch.nn.parallel import DistributedDataParallel
from torch import nn
import os
import torch


class NF4(nn.Module):
    def __init__(
        self,
        device=None,
    ):
        super().__init__()
        self.linear = nn.Linear(512, 512, bias=False, device=device)
        self.linear.weight = nn.Parameter(to_nf4(self.linear.weight))


if __name__ == "__main__":
    _local_rank = int(os.getenv("LOCAL_RANK", "0"))
    _device = f"cuda:{_local_rank}"
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="env://",
        device_id=torch.device(_local_rank),
    )
    model = NF4(_device)
    model = DistributedDataParallel(model)
torchrun --nproc_per_node=2 script.py
NotImplementedError: NF4Tensor dispatch: attempting to run c10d.broadcast_.default, this is not supported
Is there some way around this issue?
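For what it's worth, the same NotImplementedError can be reproduced without any distributed setup, which suggests the problem is simply an op missing from NF4Tensor's __torch_dispatch__ table rather than anything DDP-specific. A minimal single-process sketch (untested, mirroring the aten.cat failure from the traceback above):

import torch
from torchao.dtypes.nf4tensor import to_nf4

w = to_nf4(torch.randn(512, 512))
# aten.cat.default is not handled by NF4Tensor's __torch_dispatch__, so this
# raises the same NotImplementedError that DDP hits when it coalesces
# parameters for the construction-time broadcast.
torch.cat([w, w])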
Curious about the use case here: is it fine-tuning/QLoRA on a Llama/transformer-like model? We didn't add DDP + NF4 support because Llama/QLoRA workloads are always parallelized with FSDP first. Would you consider FSDP? For DDP, we may need to design a tensor subclass extension point if the use case is compelling enough.
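For reference, here is a rough sketch of what the FSDP route could look like for the repro module above. This assumes a PyTorch build where fully_shard (FSDP2) is importable from torch.distributed.fsdp (older builds expose it under torch.distributed._composable.fsdp) and a process group initialized as in the repro; treat it as a starting point, not a verified recipe:

from torch.distributed.fsdp import fully_shard

model = NF4(_device)
# Shard the module instead of wrapping it in DDP; NF4Tensor is exercised with
# FSDP in the Llama/QLoRA setups mentioned above. reshard_after_forward=False
# keeps the unsharded parameters around after forward, trading memory for speed.
fully_shard(model, reshard_after_forward=False)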
Yes, exactly: QLoRA fine-tuning with DDP.
I think FSDP is an option, but in my experience DDP is more efficient, even with reshard_after_forward=False.
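One thing that might be worth experimenting with in the meantime (a sketch, not an official torchao or DDP recipe): since the NF4 base weights in QLoRA are frozen, DDP can be told to skip them via the private _set_params_and_buffers_to_ignore_for_model hook, so that the construction-time broadcast and gradient bucketing only touch the remaining trainable parameters. The name "linear.weight" below matches the toy repro; in a real model you would list every NF4-quantized weight, and every rank then has to construct identical NF4 weights itself (same checkpoint and seed), since nothing broadcasts them any more:

from torch.nn.parallel import DistributedDataParallel

model = NF4(_device)
# Private, unstable API: parameters named here are excluded from DDP's initial
# state broadcast and from gradient all-reduce. In the toy repro this leaves
# DDP nothing to sync; in a real QLoRA model the LoRA adapters remain.
DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
    model, ["linear.weight"]
)
model = DistributedDataParallel(model)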