Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Jan 11, 2024
1 parent 574f2b0 commit bb03782
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,3 @@ repos:
args:
- --fix
- --exit-non-zero-on-fix
exclude: ^src/nanotron/distributed/__init__.py$
2 changes: 2 additions & 0 deletions src/nanotron/distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from nanotron.distributed.parallel_context import ParallelContext

__all__ = ["ParallelContext"]
9 changes: 9 additions & 0 deletions tests/test_distributed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import numpy as np
import pytest
import torch.distributed as dist
from helpers.utils import (
available_gpus,
get_all_3d_configurations,
Expand All @@ -14,6 +16,13 @@ def _test_init_parallel_context(parallel_context: ParallelContext):
assert isinstance(parallel_context.pp_pg, ProcessGroup) if parallel_context.pipeline_parallel_size > 1 else True
assert isinstance(parallel_context.dp_pg, ProcessGroup) if parallel_context.data_parallel_size > 1 else True

world_rank = dist.get_rank(parallel_context.world_pg)
ranks3d = parallel_context.get_3d_ranks(world_rank)
assert type(ranks3d) and len(ranks3d)

assert isinstance(parallel_context.world_rank_matrix, np.ndarray)
assert isinstance(parallel_context.world_ranks_to_pg, dict)


@pytest.mark.parametrize(
"tp,dp,pp",
Expand Down

0 comments on commit bb03782

Please sign in to comment.