[BE][Easy] enable UFMT for torch/distributed/ (pytorch#128870)
Part of pytorch#123062

- pytorch#123062

Pull Request resolved: pytorch#128870
Approved by: https://github.com/fegin
ghstack dependencies: pytorch#128868, pytorch#128869
XuehaiPan authored and pytorchmergebot committed Jun 18, 2024
1 parent 3b798df · commit a0e1e20
Showing 36 changed files with 583 additions and 298 deletions.
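
For orientation: the UFMT lint combines usort import sorting with black code formatting, so most hunks below are mechanical reorderings of import blocks plus quote and blank-line normalization. A minimal sketch of the import-ordering convention being applied, using stdlib modules as stand-ins rather than code from this PR:

    # Before formatting -- names listed in arbitrary order:
    #   import sys
    #   import pdb
    #   from collections import OrderedDict, defaultdict, Counter

    # After formatting -- plain imports sorted by module name, and names inside a
    # `from ... import` statement sorted case-insensitively (underscore-prefixed
    # names group first, as in the reordered blocks below):
    import pdb
    import sys
    from collections import Counter, defaultdict, OrderedDict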
27 changes: 0 additions & 27 deletions .lintrunner.toml
@@ -1389,33 +1389,6 @@ exclude_patterns = [
'torch/contrib/_tensorboard_vis.py',
"torch/cuda/_gpu_trace.py",
'torch/cuda/_memory_viz.py', # mypy: Value of type "object" is not indexable
'torch/distributed/__init__.py',
'torch/distributed/_composable_state.py',
'torch/distributed/_sharded_tensor/__init__.py',
'torch/distributed/_sharding_spec/__init__.py',
'torch/distributed/_tools/__init__.py',
'torch/distributed/_tools/memory_tracker.py',
'torch/distributed/argparse_util.py',
'torch/distributed/c10d_logger.py',
'torch/distributed/collective_utils.py',
'torch/distributed/constants.py',
'torch/distributed/distributed_c10d.py',
'torch/distributed/examples/memory_tracker_example.py',
'torch/distributed/launch.py',
'torch/distributed/launcher/__init__.py',
'torch/distributed/launcher/api.py',
'torch/distributed/logging_handlers.py',
'torch/distributed/nn/__init__.py',
'torch/distributed/nn/api/__init__.py',
'torch/distributed/nn/api/remote_module.py',
'torch/distributed/nn/functional.py',
'torch/distributed/nn/jit/__init__.py',
'torch/distributed/nn/jit/instantiator.py',
'torch/distributed/nn/jit/templates/__init__.py',
'torch/distributed/nn/jit/templates/remote_module_template.py',
'torch/distributed/remote_device.py',
'torch/distributed/rendezvous.py',
'torch/distributed/run.py',
'torch/fft/__init__.py',
'torch/func/__init__.py',
'torch/futures/__init__.py',
64 changes: 30 additions & 34 deletions torch/distributed/__init__.py
@@ -1,9 +1,10 @@
# mypy: allow-untyped-defs
import sys
import pdb
import sys

import torch


def is_available() -> bool:
"""
Return ``True`` if the distributed package is available.
@@ -29,31 +30,31 @@ def is_available():

if is_available():
from torch._C._distributed_c10d import (
Store,
FileStore,
TCPStore,
ProcessGroup as ProcessGroup,
Backend as _Backend,
PrefixStore,
Reducer,
Logger,
BuiltinCommHookType,
GradBucket,
Work as _Work,
_DEFAULT_FIRST_BUCKET_BYTES,
_register_comm_hook,
_register_builtin_comm_hook,
_broadcast_coalesced,
_compute_bucket_assignment_by_size,
_verify_params_across_processes,
_ControlCollectives,
_DEFAULT_FIRST_BUCKET_BYTES,
_make_nccl_premul_sum,
_register_builtin_comm_hook,
_register_comm_hook,
_StoreCollectives,
_test_python_store,
_verify_params_across_processes,
Backend as _Backend,
BuiltinCommHookType,
DebugLevel,
FileStore,
get_debug_level,
GradBucket,
Logger,
PrefixStore,
ProcessGroup as ProcessGroup,
Reducer,
set_debug_level,
set_debug_level_from_env,
_make_nccl_premul_sum,
_ControlCollectives,
_StoreCollectives,
Store,
TCPStore,
Work as _Work,
)

class _DistributedPdb(pdb.Pdb):
@@ -63,10 +64,11 @@ class _DistributedPdb(pdb.Pdb):
Usage:
_DistributedPdb().set_trace()
"""

def interaction(self, *args, **kwargs):
_stdin = sys.stdin
try:
sys.stdin = open('/dev/stdin')
sys.stdin = open("/dev/stdin")
pdb.Pdb.interaction(self, *args, **kwargs)
finally:
sys.stdin = _stdin
@@ -98,37 +100,31 @@ def breakpoint(rank: int = 0):
del guard

if sys.platform != "win32":
from torch._C._distributed_c10d import (
HashStore,
_round_robin_process_groups,
)
from torch._C._distributed_c10d import _round_robin_process_groups, HashStore

from .distributed_c10d import * # noqa: F403
from .device_mesh import DeviceMesh, init_device_mesh

# Variables prefixed with underscore are not auto imported
# See the comment in `distributed_c10d.py` above `_backend` on why we expose
# this.

from .distributed_c10d import * # noqa: F403
from .distributed_c10d import (
_all_gather_base,
_reduce_scatter_base,
_create_process_group_wrapper,
_rank_not_in_group,
_coalescing_manager,
_CoalescingManager,
_create_process_group_wrapper,
_get_process_group_name,
_rank_not_in_group,
_reduce_scatter_base,
get_node_local_rank,
)

from .remote_device import _remote_device
from .rendezvous import (
rendezvous,
_create_store_from_options,
register_rendezvous_handler,
rendezvous,
)

from .remote_device import _remote_device
from .device_mesh import init_device_mesh, DeviceMesh

set_debug_level_from_env()

else:
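Usage note for the `_DistributedPdb` / `breakpoint` hunk above (a minimal sketch, not part of the diff; it assumes a process group has already been initialized, e.g. under torchrun): every rank calls the helper, and only the selected rank gets the interactive prompt while the others wait.

    import torch.distributed as dist

    if dist.is_available() and dist.is_initialized():
        # All ranks reach this call; rank 0 drops into the _DistributedPdb prompt
        # and the remaining ranks block until it is done.
        dist.breakpoint(rank=0)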
1 change: 1 addition & 0 deletions torch/distributed/_composable/fsdp/_fsdp_collectives.py
@@ -5,6 +5,7 @@
import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.distributed_c10d import ReduceOp

from ._fsdp_common import (
_get_dim0_padded_size,
_raise_assert_with_print,
1 change: 0 additions & 1 deletion torch/distributed/_composable/fsdp/_fsdp_common.py
@@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import math
import traceback

from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, cast, List, Optional
2 changes: 1 addition & 1 deletion torch/distributed/_composable/fsdp/_fsdp_init.py
@@ -4,10 +4,10 @@
import torch
import torch.distributed as dist
import torch.nn as nn

from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh
from torch.distributed.device_mesh import _get_device_handle
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo
from ._fsdp_state import _get_module_fsdp_state

3 changes: 2 additions & 1 deletion torch/distributed/_composable/fsdp/_fsdp_param.py
@@ -7,12 +7,12 @@
import torch
import torch._dynamo.compiled_autograd as ca
import torch.nn as nn

from torch._prims_common import make_contiguous_strides_for
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed._tensor.device_mesh import _mesh_resources
from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta

from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_common import (
_chunk_with_empty,
@@ -24,6 +24,7 @@
HSDPMeshInfo,
)


"""
[Note: FSDP tensors]
FSDP considers the following tensors:
3 changes: 2 additions & 1 deletion torch/distributed/_composable/fsdp/_fsdp_param_group.py
@@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import contextlib

from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple

import torch
@@ -11,6 +10,7 @@
from torch.profiler import record_function
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.utils.hooks import RemovableHandle

from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_collectives import (
AllGatherResult,
Expand All @@ -21,6 +21,7 @@
from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo, TrainingState
from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState


_ModuleToHandleDict = Dict[nn.Module, RemovableHandle] # for state dict


3 changes: 2 additions & 1 deletion torch/distributed/_composable/fsdp/_fsdp_state.py
@@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import functools

from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING

import torch
@@ -13,10 +12,12 @@
)
from torch.distributed.utils import _to_kwargs
from torch.utils._pytree import tree_flatten, tree_map

from ._fsdp_api import MixedPrecisionPolicy
from ._fsdp_common import _cast_fp_tensor, TrainingState
from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup


if TYPE_CHECKING:
from ._fsdp_param import FSDPParam

1 change: 0 additions & 1 deletion torch/distributed/_composable/fsdp/fully_shard.py
@@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import functools

from typing import Any, cast, Iterable, List, NoReturn, Optional, Union

import torch
1 change: 0 additions & 1 deletion torch/distributed/_composable/fully_shard.py
@@ -8,7 +8,6 @@
from torch.distributed._composable_state import _get_module_state, _insert_module_state
from torch.distributed.fsdp._common_utils import _FSDPState
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo

from torch.distributed.fsdp._init_utils import (
_init_buffer_state,
_init_core_state,
1 change: 1 addition & 0 deletions torch/distributed/_composable/replicate.py
@@ -9,6 +9,7 @@

from .contract import _get_registry, contract


_ROOT_MODULE_PREFIX = ""


3 changes: 1 addition & 2 deletions torch/distributed/_cuda_p2p/__init__.py
@@ -1,15 +1,14 @@
# mypy: allow-untyped-defs
from collections import defaultdict
from contextlib import contextmanager

from functools import partial
from typing import Callable, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
import torch.distributed._functional_collectives as funcol

import torch.distributed.distributed_c10d as c10d


if TYPE_CHECKING:
from torch._C._distributed_c10d import _DistributedBackendOptions, Backend

2 changes: 2 additions & 0 deletions torch/distributed/_functional_collectives.py
@@ -11,6 +11,7 @@

from . import _functional_collectives_impl as fun_col_impl


try:
from torch.utils._cxx_pytree import tree_map_only
except ImportError:
@@ -1134,6 +1135,7 @@ def all_gather_inplace(
reduce_scatter_tensor as legacy_reducescatter,
)


# This dict should contain sets of functions that dynamo is allowed to remap.
# Functions in this set should accept the same args/kwargs 1:1 as their mapping.
traceable_collective_remaps = {
1 change: 1 addition & 0 deletions torch/distributed/_functional_collectives_impl.py
@@ -4,6 +4,7 @@
import torch
import torch.distributed.distributed_c10d as c10d


"""
This file contains the op impls for the legacy (c10d_functional) functional collectives.
These impls simply call into the native (_c10d_functional) functional collectives.
7 changes: 5 additions & 2 deletions torch/distributed/_sharded_tensor/__init__.py
@@ -1,11 +1,12 @@
# Keep old package for BC purposes, this file should be removed once
# everything moves to the `torch.distributed._shard` package.
import sys
import torch
import warnings

import torch
from torch.distributed._shard.sharded_tensor import * # noqa: F403


with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
@@ -15,4 +16,6 @@
stacklevel=2,
)

sys.modules['torch.distributed._sharded_tensor'] = torch.distributed._shard.sharded_tensor
sys.modules[
"torch.distributed._sharded_tensor"
] = torch.distributed._shard.sharded_tensor
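
The `sys.modules` assignment above is the backward-compatibility shim described by the leading comment: once the alias is installed, the old dotted path resolves to the relocated package, and importing it triggers the `warnings.warn` call shown. A minimal sketch of the intended effect, assuming a PyTorch build with the distributed package available:

    import torch.distributed._shard.sharded_tensor as new_pkg  # canonical location
    import torch.distributed._sharded_tensor as old_pkg  # deprecated alias

    # Both names refer to the same module object via the sys.modules entry.
    assert old_pkg is new_pkg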
7 changes: 5 additions & 2 deletions torch/distributed/_sharding_spec/__init__.py
@@ -1,11 +1,12 @@
# Keep old package for BC purposes, this file should be removed once
# everything moves to the `torch.distributed._shard` package.
import sys
import torch
import warnings

import torch
from torch.distributed._shard.sharding_spec import * # noqa: F403


with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
@@ -16,4 +17,6 @@
)

import torch.distributed._shard.sharding_spec as _sharding_spec
sys.modules['torch.distributed._sharding_spec'] = _sharding_spec


sys.modules["torch.distributed._sharding_spec"] = _sharding_spec
1 change: 1 addition & 0 deletions torch/distributed/_state_dict_utils.py
@@ -22,6 +22,7 @@
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor


if dist.is_available() or TYPE_CHECKING:
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import ShardedTensor