meta device issue with float8 delayed scale #654

Open
weifengpy opened this issue Oct 25, 2024 · 8 comments
Labels
bug Something isn't working

Comments

@weifengpy
Contributor

weifengpy commented Oct 25, 2024

repro:

CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.scaling_type_weight "delayed" --metrics.log_freq 1 --training.steps 3 --checkpoint.enable_checkpoint --checkpoint.interval 2
  traceback : Traceback (most recent call last):
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper   
      return f(*args, **kwargs)
    File "/data/users/weif/torchtitan/train.py", line 301, in main
      pred = model(input_ids)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
      return inner()
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
      result = forward_call(*args, **kwargs)
    File "/data/users/weif/torchtitan/torchtitan/models/llama/model.py", line 439, in forward
      h = layer(h, self.freqs_cis)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
      return inner()
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1769, in inner
      args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_state.py", line 67, in fsdp_hook_wrapper
      return torch._dynamo.disable(func, recursive=True)(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 629, in _fn
      return fn(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_state.py", line 234, in _pre_forward
      args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 314, in pre_forward
      self.unshard(self.unshard_async_op)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 243, in unshard
      self._all_gather_result = foreach_all_gather(
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
      return func(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_collectives.py", line 139, in foreach_all_gather
      param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params)    
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
      return func(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_collectives.py", line 217, in _get_param_all_gather_inputs
      param_all_gather_inputs[i] = fsdp_param.all_gather_inputs
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 702, in all_gather_inputs
      ) = sharded_local_tensor.fsdp_pre_all_gather(self.mesh_info.mesh)
    File "/data/users/weif/ao/torchao/float8/fsdp_utils.py", line 408, in fsdp_pre_all_gather
      float8_tensor = hp_tensor_to_float8_delayed(
    File "/data/users/weif/ao/torchao/float8/float8_scaling_utils.py", line 105, in hp_tensor_to_float8_delayed
      return hp_tensor_and_scale_to_float8(
    File "/data/users/weif/ao/torchao/float8/float8_tensor.py", line 254, in hp_tensor_and_scale_to_float8
      return _ToFloat8ConstrFunc.apply(
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
      return super().apply(*args, **kwargs)  # type: ignore[misc]
    File "/data/users/weif/ao/torchao/float8/float8_tensor.py", line 170, in forward
      tensor_scaled = tensor.to(torch.float32) * scale
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 289, in _fn
      result = fn(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 141, in _fn
      result = fn(**bound.arguments)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_refs/__init__.py", line 1064, in _ref
      output = prim(a, b)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_refs/__init__.py", line 1671, in mul
      return prims.mul(a, b)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_ops.py", line 723, in __call__
      return self._op(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_library/fake_impl.py", line 95, in meta_kernel
      return fake_impl_holder.kernel(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_library/utils.py", line 20, in __call__
      return self.func(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/library.py", line 1190, in inner
      return func(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 614, in fake_impl
      return self._abstract_fn(*args, **kwargs)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_prims/__init__.py", line 402, in _prim_elementwise_meta
      utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
    File "/home/xxx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 742, in check_same_device
      raise RuntimeError(msg)
  RuntimeError: Tensor on device meta is not on the expected device cuda:1!
@weifengpy
Contributor Author

cc @vkuzo

@tianyu-l tianyu-l added the bug Something isn't working label Oct 25, 2024
@vkuzo
Contributor

vkuzo commented Oct 25, 2024

without debugging, my guess would be something like:

  1. model is created on device meta
  2. checkpoint is loaded with device cuda, but it does not have the extra buffers for delayed scaling

I can take a look next week, unless someone gets to it faster

@weifengpy
Contributor Author

> checkpoint is loaded with device cuda, but it does not have the extra buffers for delayed scaling

If running the repro for the 1st time, the torchtitan/output/checkpoint folder will be empty, so the model won't load checkpoints, but the error is still there. We do meta init and call init_weights to move the model from meta to cuda; the buffers for delayed scaling might need some treatment (a rough sketch of this flow follows below).

> I can take a look next week, unless someone gets to it faster

thanks!
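
For reference, here is a minimal, hypothetical sketch of the meta-init flow described above (stand-in module and init logic only, not torchtitan's actual code). Any extra amax/scale buffers that delayed scaling registers have to survive the `to_empty`/re-init step, which appears to be what the traceback above is hitting.

```python
import torch
import torch.nn as nn

# Build the model on the meta device: no real storage is allocated yet.
# (In the real flow, the float8 conversion with delayed scaling would also
# be applied here before sharding with FSDP2.)
with torch.device("meta"):
    model = nn.Linear(16, 16)  # stand-in for the llama3 model

# Materialize real storage on the target device; values are uninitialized here.
model.to_empty(device="cuda")

# Stand-in for model.init_weights(): re-initialize the now-real parameters.
with torch.no_grad():
    for p in model.parameters():
        p.normal_(mean=0.0, std=0.02)
```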

@vkuzo
Contributor

vkuzo commented Oct 25, 2024

I see, then this line is relevant: https://github.com/pytorch/ao/blob/e85c1a318b06bbdb3b8c7f92f3257999864446b0/torchao/float8/float8_linear.py#L648

We'll have to think about whether we can do this automatically without introducing one more API. If not, we'll have to design such an API.

@weifengpy
Contributor Author

weifengpy commented Oct 25, 2024

> I see, then this line is relevant: https://github.com/pytorch/ao/blob/e85c1a318b06bbdb3b8c7f92f3257999864446b0/torchao/float8/float8_linear.py#L648
>
> We'll have to think about whether we can do this automatically without introducing one more API. If not, we'll have to design such an API.

I see. It sounds plausible.

vkuzo added a commit to pytorch/ao that referenced this issue Nov 15, 2024
Summary:

Context: pytorch/torchtitan#654

If the user has delayed scaling and FSDP float8 all-gather on, there is
a subtle bug that can happen if the user calls
`model.to_empty(device="cuda")`:
1. to_empty recreates the buffers for tracking weight amax and scale
2. (1) leaves the old buffers still pointed to by `Float8Linear.weight._amax_buffer`, etc. orphaned, because those extra references don't participate in `to_empty`

I couldn't think of an easy and clean way to auto-fix this, since we can't expect
`torch.nn.Module` to know that our logic has multiple references to the same
buffer, so I'm exposing a private API for now until we can think of something better.

With the current fix, the user can then call
`_maybe_fixup_delayed_scaling_buffers` manually to relink the buffers to
the correct new versions.

Test Plan: CI

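To make the failure mode concrete, here is a hedged, self-contained sketch of the orphaned-reference problem described in the commit message above. The toy module and attribute names are illustrative only, not the real Float8Linear internals:

```python
import torch
import torch.nn as nn

class TinyFloat8Linear(nn.Linear):
    """Toy stand-in: registers an amax buffer and keeps an extra reference to it."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features)
        self.register_buffer("fp8_amax_weight", torch.zeros(1))
        # Extra, unregistered reference, analogous to weight._amax_buffer.
        self.extra_ref = self.fp8_amax_weight

with torch.device("meta"):
    m = TinyFloat8Linear(4, 4)

# to_empty() re-creates registered parameters/buffers on the target device ...
m.to_empty(device="cpu")
print(m.fp8_amax_weight.device)  # cpu  -> the registered buffer was re-created
# ... but a plain Python reference does not participate and is left behind.
print(m.extra_ref.device)        # meta -> the orphaned reference

# The private API described above would re-link such references after
# materialization; for this toy module that is roughly equivalent to:
m.extra_ref = m.fp8_amax_weight
```

In the real code, `_maybe_fixup_delayed_scaling_buffers` is the name given in the commit message for that re-linking step; its exact location and signature are not shown here.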
@vkuzo
Contributor

vkuzo commented Nov 15, 2024

I really don't love this solution, but we could do something like this: pytorch/ao#1292. Thoughts?

vkuzo added a commit to pytorch/ao that referenced this issue Nov 15, 2024
@tianyu-l tianyu-l added this to the torchtitan release 1.0 milestone Nov 19, 2024
@weifengpy
Contributor Author

> I really don't love this solution, but we could do something like this: pytorch/ao#1292. Thoughts?

thanks for the fix!

@vkuzo
Contributor

vkuzo commented Nov 21, 2024

Reopening as the fix isn't landed yet :)

@vkuzo vkuzo reopened this Nov 21, 2024