From e0314aa1ba37ea30bc75f2282860b5611339c415 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 10 Jan 2024 11:12:38 +0000 Subject: [PATCH 1/9] add ParallelContext --- .pre-commit-config.yaml | 1 + src/nanotron/distributed/__init__.py | 2 + src/nanotron/distributed/parallel_context.py | 356 +++++++++++++++++++ src/nanotron/distributed/parallel_mode.py | 9 + tests/helpers/utils.py | 14 +- tests/test_distributed.py | 62 ++++ 6 files changed, 440 insertions(+), 4 deletions(-) create mode 100644 src/nanotron/distributed/__init__.py create mode 100644 src/nanotron/distributed/parallel_context.py create mode 100644 src/nanotron/distributed/parallel_mode.py create mode 100644 tests/test_distributed.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a6045cfb..7becc218 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,3 +19,4 @@ repos: args: - --fix - --exit-non-zero-on-fix + exclude: ^src/nanotron/distributed/__init__.py$ diff --git a/src/nanotron/distributed/__init__.py b/src/nanotron/distributed/__init__.py new file mode 100644 index 00000000..33a98f44 --- /dev/null +++ b/src/nanotron/distributed/__init__.py @@ -0,0 +1,2 @@ +from nanotron.distributed.parallel_context import ParallelContext +from nanotron.distributed.parallel_mode import ParallelMode diff --git a/src/nanotron/distributed/parallel_context.py b/src/nanotron/distributed/parallel_context.py new file mode 100644 index 00000000..3166168d --- /dev/null +++ b/src/nanotron/distributed/parallel_context.py @@ -0,0 +1,356 @@ +import os +from typing import Dict, Literal, Tuple + +import numpy as np +import torch +import torch.distributed as dist +from nanotron.distributed.parallel_mode import ParallelMode + +DistributedBackend = Literal["gloo", "mpi", "nccl"] +RanksToDevice = Dict[ParallelMode, int] + + +class ParallelContext: + # @classmethod + # def from_torch( + # cls, + # tensor_parallel_size: int, + # pipeline_parallel_size: int, + # data_parallel_size: int, + # backend: DistributedBackend = "nccl", + # ): + # """Initialize parallel context based on the environment variables defined by torchrun.""" + # rank = int(os.environ["RANK"]) + # local_rank = int(os.environ["LOCAL_RANK"]) + # world_size = int(os.environ["WORLD_SIZE"]) + # local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + # host = os.environ["MASTER_ADDR"] + # # TODO(xrsrke): make it auto search for ports? + # port = int(os.environ["MASTER_PORT"]) + + # return cls( + # rank=rank, + # local_rank=local_rank, + # world_size=world_size, + # local_world_size=local_world_size, + # host=host, + # port=port, + # backend=backend, + # tensor_parallel_size=tensor_parallel_size, + # pipeline_parallel_size=pipeline_parallel_size, + # data_parallel_size=data_parallel_size, + # ) + + def __init__( + self, + tensor_parallel_size: int, + pipeline_parallel_size: int, + data_parallel_size: int, + backend: DistributedBackend = "nccl", + ): + """Initialize parallel context.""" + num_gpus_per_model = tensor_parallel_size * pipeline_parallel_size + world_size = int(os.environ["WORLD_SIZE"]) + + assert ( + world_size % data_parallel_size == 0 + ), "The total number of processes must be divisible by the data parallel size." + assert world_size % num_gpus_per_model == 0, ( + "The total number of processes must be divisible by" + "the number of GPUs per model (tensor_parallel_size * pipeline_parallel_size)." 
+ ) + assert num_gpus_per_model * data_parallel_size == world_size, ( + "The number of process requires to train all replicas", + "must be equal to the world size.", + ) + + if not dist.is_available(): + raise ValueError("`torch.distributed is not available as a package, please install it.") + + self.tensor_parallel_size = tensor_parallel_size + self.pipeline_parallel_size = pipeline_parallel_size + self.data_parallel_size = data_parallel_size + + # self._global_ranks = {} + # self._local_ranks = {} + # self._world_sizes = {} + self._groups = {} + # self._ranks_in_group = {} + # self._ranks_to_device = {} + + # self.local_rank = local_rank + # self.local_world_size = local_world_size + + self.set_device() + + if not dist.is_initialized(): + rank = int(os.environ["RANK"]) + # local_rank = int(os.environ["LOCAL_RANK"]) + # local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + host = os.environ["MASTER_ADDR"] + # TODO(xrsrke): make it auto search for ports? + port = int(os.environ["MASTER_PORT"]) + self.init_global_dist(rank, world_size, backend, host, port) + + self.init_parallel_groups() + dist.barrier() + + def init_global_dist(self, rank: int, world_size: int, backend: DistributedBackend, host: str, port: int): + """Initialize the global distributed group. + + Args: + rank (int): global rank + world_size (int): global world size + backend (DistributedBackend): distributed backend + host (str): communication host + port (int): communication port + """ + assert backend == "nccl", "Only nccl backend is supported for now." + + init_method = f"tcp://{host}:{port}" + dist.init_process_group( + rank=rank, world_size=world_size, backend=backend, init_method=init_method, timeout=dist.default_pg_timeout + ) + ranks = list(range(world_size)) + process_group = dist.new_group( + ranks=ranks, + backend=dist.get_backend(), + ) + # self._register_dist(rank, world_size, process_group, ranks_in_group=ranks, parallel_mode=ParallelMode.GLOBAL) + # self.add_group(ParallelMode.GLOBAL, process_group) + self.world_pg = process_group + # self.add_global_rank(ParallelMode.GLOBAL, rank) + + def init_parallel_groups(self): + """Initialize 3D parallelism's all process groups.""" + # rank = self.get_global_rank() + + # NOTE: ensure all processes have joined the global group + # before creating other groups + dist.barrier(group=self.world_pg) + + # rank = self.get_global_rank() + rank = int(os.environ["RANK"]) + # world_size = self.get_world_size(ParallelMode.GLOBAL) + world_size = int(os.environ["WORLD_SIZE"]) + ranks = np.arange(0, world_size).reshape( + (self.pipeline_parallel_size, self.data_parallel_size, self.tensor_parallel_size) + ) + world_ranks_to_pg = {} + + tp_pg: dist.ProcessGroup + ranks_with_tp_last = ranks.reshape( + (self.pipeline_parallel_size * self.data_parallel_size, self.tensor_parallel_size) + ) + for tp_ranks in ranks_with_tp_last: + sorted_ranks = tuple(sorted(tp_ranks)) + if sorted_ranks not in world_ranks_to_pg: + new_group = dist.new_group(ranks=tp_ranks) + world_ranks_to_pg[sorted_ranks] = new_group + else: + new_group = world_ranks_to_pg[sorted_ranks] + if rank in tp_ranks: + tp_pg = new_group + + dp_pg: dist.ProcessGroup + ranks_with_dp_last = ranks.transpose((0, 2, 1)).reshape( + (self.pipeline_parallel_size * self.tensor_parallel_size, self.data_parallel_size) + ) + for dp_ranks in ranks_with_dp_last: + sorted_ranks = tuple(sorted(dp_ranks)) + if sorted_ranks not in world_ranks_to_pg: + new_group = dist.new_group(ranks=dp_ranks) + world_ranks_to_pg[sorted_ranks] = new_group + else: + 
new_group = world_ranks_to_pg[sorted_ranks] + if rank in dp_ranks: + dp_pg = new_group + + pp_pg: dist.ProcessGroup + ranks_with_pp_last = ranks.transpose((2, 1, 0)).reshape( + (self.tensor_parallel_size * self.data_parallel_size, self.pipeline_parallel_size) + ) + for pp_ranks in ranks_with_pp_last: + sorted_ranks = tuple(sorted(pp_ranks)) + if sorted_ranks not in world_ranks_to_pg: + new_group = dist.new_group(ranks=pp_ranks) + world_ranks_to_pg[sorted_ranks] = new_group + else: + new_group = world_ranks_to_pg[sorted_ranks] + if rank in pp_ranks: + pp_pg = new_group + + # TODO(xrsrke): this looks unnecessary, remove it if possible + # We build model parallel group (combination of both tensor parallel and pipeline parallel) + for dp_rank in range(self.data_parallel_size): + pp_and_tp_ranks = ranks[:, dp_rank, :].reshape(-1) + sorted_ranks = tuple(sorted(pp_and_tp_ranks)) + if sorted_ranks not in world_ranks_to_pg: + new_group = dist.new_group(ranks=pp_and_tp_ranks) + world_ranks_to_pg[sorted_ranks] = new_group + + self.tp_pg = tp_pg + self.dp_pg = dp_pg + self.pp_pg = pp_pg + + # parallel_mode_to_pg = { + # ParallelMode.TENSOR: tp_pg, + # ParallelMode.PIPELINE: pp_pg, + # ParallelMode.DATA: dp_pg, + # } + # for parallel_mode in [ParallelMode.TENSOR, ParallelMode.PIPELINE, ParallelMode.DATA]: + # process_group = parallel_mode_to_pg[parallel_mode] + # # self.add_local_rank(parallel_mode, dist.get_rank(process_group)) + # # self.add_world_size(parallel_mode, dist.get_world_size(process_group)) + # self.add_group(parallel_mode, process_group) + # # self.add_ranks_in_group(parallel_mode, dist.get_process_group_ranks(process_group)) + + # TODO(xrsrke): remove world_rank_matrix, world_ranks_to_pg + self.world_rank_matrix = ranks + self.world_ranks_to_pg = world_ranks_to_pg + + dist.barrier() + + # def _register_dist( + # self, + # local_rank: int, + # local_world_size: int, + # process_group: dist.ProcessGroup, + # ranks_in_group: List[int], + # parallel_mode: ParallelMode, + # ): + # """Register distributed group based on the parallel mode. + + # Args: + # local_rank (int): local rank + # local_world_size (int): local world size + # mode (ParallelMode): parallel mode + # """ + # self.add_local_rank(parallel_mode, local_rank) + # self.add_world_size(parallel_mode, local_world_size) + # self.add_group(parallel_mode, process_group) + # self.add_ranks_in_group(parallel_mode, ranks_in_group) + + def set_device(self): + local_rank = int(os.getenv("LOCAL_RANK", "0")) + + # NOTE: Set the device id. + # `torch.cuda.device_count` should return the number of device on a single node. + # We assume the nodes to be homogeneous (same number of gpus per node) + device_id = local_rank + torch.cuda.set_device(torch.cuda.device(device_id)) + + def map_rank_to_device(self): + """Map global rank to device.""" + local_rank = int(os.getenv("LOCAL_RANK", "0")) + + # NOTE: Set the device id. + # `torch.cuda.device_count` should return the number of device on a single node. + # We assume the nodes to be homogeneous (same number of gpus per node) + device_id = local_rank + torch.cuda.set_device(torch.cuda.device(device_id)) + + def is_initialized(self, parallel_mode: ParallelMode) -> bool: + """Check if the parallel mode is initialized. 
+ + Args: + mode (ParallelMode): parallel mode + + Returns: + bool: True if the parallel mode is initialized, False otherwise + """ + return True if parallel_mode in self._groups else False + + # def get_global_rank(self) -> int: + # """Get the global rank of the local process.""" + # return self._global_ranks[ParallelMode.GLOBAL] + + # def add_global_rank(self, parallel_mode: ParallelMode, rank: int): + # """Add the global rank of the local process.""" + # self._global_ranks[parallel_mode] = rank + + # def get_local_rank(self, parallel_mode: ParallelMode) -> int: + # """Get the local rank of the local process in a given parallel mode.""" + # return self._local_ranks[parallel_mode] + + # def add_local_rank(self, parallel_mode: ParallelMode, rank: int): + # """Add the local rank of the local process in a given parallel mode.""" + # self._local_ranks[parallel_mode] = rank + + def get_global_rank_from_local_rank(self, local_rank: int, parallel_mode: ParallelMode) -> int: + """Get the global rank from a local rank in a given parallel mode.""" + process_group = self.get_group(parallel_mode) + return dist.get_global_rank(process_group, local_rank) + + # # TODO(xrsrke): add cache + # def get_world_size(self, parallel_mode: ParallelMode) -> int: + # """Get the world size of a given parallel mode.""" + # return self._world_sizes[parallel_mode] + + # def add_world_size(self, parallel_mode: ParallelMode, world_size: int): + # """Add the world size of a given parallel mode.""" + # self._world_sizes[parallel_mode] = world_size + + def add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup) -> int: + """Add a process group of a given parallel mode.""" + self._groups[parallel_mode] = group + + # TODO(xrsrke): add cache + def get_group(self, parallel_mode: ParallelMode) -> dist.ProcessGroup: + """Get a process group of a given parallel mode.""" + return self._groups[parallel_mode] + + # def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks_in_group: List[int]): + # """Add a list of global ranks in a given parallel mode of the local process.""" + # self._ranks_in_group[parallel_mode] = ranks_in_group + + # def get_ranks_in_group(self, parallel_mode: ParallelMode) -> List[int]: + # """A list of global ranks in a given parallel mode of the local process.""" + # return self._ranks_in_group[parallel_mode] + + # def get_next_local_rank(self, rank, parallel_mode: ParallelMode) -> int: + # """Get the next local rank in a given parallel mode.""" + # world_size = self.get_world_size(parallel_mode) + # return (rank + 1) % world_size + + # def get_prev_local_rank(self, rank, parallel_mode: ParallelMode) -> int: + # """Get the previous local rank in a given parallel mode.""" + # world_size = self.get_world_size(parallel_mode) + # return (rank - 1) % world_size + + # def is_first_rank(self, parallel_mode: ParallelMode) -> bool: + # local_rank = self.get_local_rank(parallel_mode) + # return local_rank == 0 + + # def is_last_rank(self, parallel_mode: ParallelMode) -> bool: + # local_rank = self.get_local_rank(parallel_mode) + # world_size = self.get_world_size(parallel_mode) + # return local_rank == world_size - 1 + + def get_3d_ranks(self, local_rank: int, parallel_mode: ParallelMode = ParallelMode.GLOBAL) -> Tuple[int, int, int]: + rank = self.get_global_rank_from_local_rank(local_rank, parallel_mode) + tp_world_size = self.get_world_size(ParallelMode.TENSOR) + dp_world_size = self.get_world_size(ParallelMode.DATA) + pp_world_size = self.get_world_size(ParallelMode.PIPELINE) + + pp_rank = (rank 
// (tp_world_size * dp_world_size)) % pp_world_size + dp_rank = (rank // tp_world_size) % dp_world_size + tp_rank = rank % tp_world_size + return (pp_rank, dp_rank, tp_rank) + + def destroy(self): + assert self.is_initialized(ParallelMode.GLOBAL), "Global group must be initialized before destroying." + for mode, group in self._groups.items(): + # NOTE: we destroy the global group last + if mode is not ParallelMode.GLOBAL: + if self.is_initialized(mode) and self.get_world_size(mode) > 1: + # NOTE: only ranks in the parallel group need to synchronize + # before destroying the group + group = self.get_group(mode) + dist.barrier(group=group) + dist.destroy_process_group(group) + + dist.barrier() + dist.destroy_process_group() + + self._groups.clear() diff --git a/src/nanotron/distributed/parallel_mode.py b/src/nanotron/distributed/parallel_mode.py new file mode 100644 index 00000000..45419977 --- /dev/null +++ b/src/nanotron/distributed/parallel_mode.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class ParallelMode(Enum): + GLOBAL = "global" + + TENSOR = "tensor" + PIPELINE = "pipeline" + DATA = "data" diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index ac3d755b..6ca3e81a 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple import torch.cuda -from nanotron.core.process_groups import get_process_groups +from nanotron.distributed import ParallelContext from torch.distributed.launcher import elastic_launch @@ -72,14 +72,20 @@ def __init__(self, func, args, kwargs, tp: int, dp: int, pp: int): def __call__(self): with mock_os_environ(update_key_values={"WORLD_SIZE": f"{self.tp * self.dp * self.pp}"}): - dpg = get_process_groups( + # dpg = get_process_groups( + # data_parallel_size=self.dp, + # pipeline_parallel_size=self.pp, + # tensor_parallel_size=self.tp, + # ) + + parallel_context = ParallelContext( data_parallel_size=self.dp, pipeline_parallel_size=self.pp, tensor_parallel_size=self.tp, ) - assert "dpg" not in self.kwargs - self.kwargs["dpg"] = dpg + assert "parallel_context" not in self.kwargs + self.kwargs["parallel_context"] = parallel_context self.func(*self.args, **self.kwargs) diff --git a/tests/test_distributed.py b/tests/test_distributed.py new file mode 100644 index 00000000..506d8927 --- /dev/null +++ b/tests/test_distributed.py @@ -0,0 +1,62 @@ +import pytest +from helpers.utils import ( + available_gpus, + get_all_3d_configurations, + init_distributed, +) +from nanotron.distributed import ParallelContext +from torch.distributed import ProcessGroup + + +def _test_init_parallel_context(parallel_context: ParallelContext): + # parallel_modes = [ + # ParallelMode.GLOBAL, + # ParallelMode.TENSOR, + # ParallelMode.PIPELINE, + # ParallelMode.DATA, + # ] + + # assert isinstance(parallel_context.get_global_rank(), int) + + # for parallel_mode in parallel_modes: + # # local_rank = parallel_context.get_local_rank(parallel_mode) + + # assert parallel_context.is_initialized(parallel_mode) is True + # # assert isinstance(parallel_context.get_group(parallel_mode), ProcessGroup) + + # # assert type(parallel_context.get_local_rank(parallel_mode)) == int + # # assert type(parallel_context.get_world_size(parallel_mode)) == int + + # # process_group = parallel_context.get_group(parallel_mode) + # # assert isinstance(process_group, ProcessGroup) + # # ranks_in_group = parallel_context.get_ranks_in_group(parallel_mode) + # # TODO(xrsrke): do an expected list of ranks + # # assert ranks_in_group == 
dist.get_process_group_ranks(process_group) + # # assert len(ranks_in_group) == parallel_context.get_world_size(parallel_mode) + + # # assert parallel_context.is_first_rank(parallel_mode) == (local_rank == 0) + # # assert parallel_context.is_last_rank(parallel_mode) == ( + # # local_rank == parallel_context.get_world_size(parallel_mode) - 1 + # # ) + + assert isinstance(parallel_context.world_pg, ProcessGroup) + assert isinstance(parallel_context.tp_pg, ProcessGroup) if parallel_context.tensor_parallel_size > 1 else True + assert isinstance(parallel_context.pp_pg, ProcessGroup) if parallel_context.pipeline_parallel_size > 1 else True + assert isinstance(parallel_context.dp_pg, ProcessGroup) if parallel_context.data_parallel_size > 1 else True + + # parallel_context.destroy() + + # for parallel_mode in parallel_modes: + # assert parallel_context.is_initialized(parallel_mode) is False + + +@pytest.mark.parametrize( + "tp,dp,pp", + [ + pytest.param(*all_3d_configs) + for gpus in range(1, min(available_gpus(), 4) + 1) + for all_3d_configs in get_all_3d_configurations(gpus) + ], +) +def test_init_parallel_context(tp: int, dp: int, pp: int): + init_distributed(tp=tp, dp=dp, pp=pp)(_test_init_parallel_context)() From 3c8311824349af226d6554681dd947a09454faa2 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 10 Jan 2024 11:56:15 +0000 Subject: [PATCH 2/9] add ParallelContext to tests --- README.md | 14 +- src/nanotron/distributed/parallel_context.py | 21 +-- tests/test_clip_grads.py | 140 ++++++++++-------- tests/test_data_parallel.py | 16 +- tests/test_p2p.py | 6 +- ..._parameters_accumulate_gradient_in_fp32.py | 82 +++++----- tests/test_pipeline_parallel.py | 108 +++++++------- tests/test_random_state.py | 8 +- tests/test_serialize.py | 128 +++++++++------- tests/test_tensor_parallel.py | 122 +++++++++------ tests/test_tie_weights.py | 55 ++++--- tests/test_zero.py | 118 +++++++++------ 12 files changed, 462 insertions(+), 356 deletions(-) diff --git a/README.md b/README.md index a9b92b9c..22037c16 100644 --- a/README.md +++ b/README.md @@ -74,14 +74,18 @@ We showcase usage in the `examples` directory. Let's go through some key concepts. -## DistributedProcessGroups +## ParallelContext -`DistributedProcessGroups` is the base class referencing all the process groups you might need when running parallel workloads. You can initialize it using the following: +`ParallelContext` is the base class referencing all the process groups you might need when running parallel workloads. You can initialize it using the following: ```python -from nanotron.core.process_groups import get_process_groups +from nanotron.distributed import ParallelContext -dp, tp, pp = ... # Predefine your topology -dpg: DistributedProcessGroups = get_process_groups(data_parallel_size=dp, tensor_parallel_size=tp, pipeline_parallel_size=pp) +# define your topology +parallel_context = ParallelContext.from_torch( + tensor_parallel_size=2, + data_parallel_size=2, + pipeline_parallel_size=2 +) ``` `ProcessGroups` is a mechanism in order to run distributed collectives (`all-reduce`, `all-gather`, ...) on a subgroup of all the ranks. It provides the granularity needed for 3D parallelism. 
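
As a minimal illustration (a sketch only, assuming the `parallel_context` above was created under `torchrun` with a matching world size), a collective can be issued on one of the sub-groups that `ParallelContext` builds (`tp_pg`, `dp_pg`, `pp_pg`):

```python
import torch
import torch.distributed as dist

# Usage sketch: sum a tensor across the tensor-parallel group only.
# Every rank in the same tp group contributes 1.0, so after the all-reduce
# the value equals tensor_parallel_size on each of those ranks.
x = torch.ones(1, device="cuda")
dist.all_reduce(x, op=dist.ReduceOp.SUM, group=parallel_context.tp_pg)

# Passing group=parallel_context.dp_pg (or pp_pg) instead would reduce
# across data-parallel (or pipeline-parallel) ranks.
```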
diff --git a/src/nanotron/distributed/parallel_context.py b/src/nanotron/distributed/parallel_context.py index 3166168d..f352649b 100644 --- a/src/nanotron/distributed/parallel_context.py +++ b/src/nanotron/distributed/parallel_context.py @@ -327,15 +327,18 @@ def get_group(self, parallel_mode: ParallelMode) -> dist.ProcessGroup: # world_size = self.get_world_size(parallel_mode) # return local_rank == world_size - 1 - def get_3d_ranks(self, local_rank: int, parallel_mode: ParallelMode = ParallelMode.GLOBAL) -> Tuple[int, int, int]: - rank = self.get_global_rank_from_local_rank(local_rank, parallel_mode) - tp_world_size = self.get_world_size(ParallelMode.TENSOR) - dp_world_size = self.get_world_size(ParallelMode.DATA) - pp_world_size = self.get_world_size(ParallelMode.PIPELINE) - - pp_rank = (rank // (tp_world_size * dp_world_size)) % pp_world_size - dp_rank = (rank // tp_world_size) % dp_world_size - tp_rank = rank % tp_world_size + def get_3d_ranks(self, world_rank: int) -> Tuple[int, int, int]: + # tp_world_size = self.get_world_size(ParallelMode.TENSOR) + # dp_world_size = self.get_world_size(ParallelMode.DATA) + # pp_world_size = self.get_world_size(ParallelMode.PIPELINE) + + # pp_rank = (world_rank // (tp_world_size * dp_world_size)) % pp_world_size + # dp_rank = (world_rank // tp_world_size) % dp_world_size + # tp_rank = world_rank % tp_world_size + # return (pp_rank, dp_rank, tp_rank) + pp_rank = (world_rank // (self.tp_pg.size() * self.dp_pg.size())) % self.pp_pg.size() + dp_rank = (world_rank // self.tp_pg.size()) % self.dp_pg.size() + tp_rank = world_rank % self.tp_pg.size() return (pp_rank, dp_rank, tp_rank) def destroy(self): diff --git a/tests/test_clip_grads.py b/tests/test_clip_grads.py index 72664602..08b020b1 100644 --- a/tests/test_clip_grads.py +++ b/tests/test_clip_grads.py @@ -24,8 +24,8 @@ sync_tied_weights_gradients, tie_parameters, ) -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.utils import assert_tensor_synced_across_pg, init_on_device_and_dtype +from nanotron.distributed import ParallelContext from torch import nn @@ -35,13 +35,13 @@ def test_clip_grads_with_pp(norm_type: float): init_distributed(tp=1, dp=1, pp=2)(_test_clip_grads_with_pp)(norm_type=norm_type) -def _test_clip_grads_with_pp(dpg: DistributedProcessGroups, norm_type: float): +def _test_clip_grads_with_pp(parallel_context: ParallelContext, norm_type: float): device = torch.device("cuda") - p2p = P2P(dpg.pp_pg, device=device) + p2p = P2P(parallel_context.pp_pg, device=device) reference_rank = 0 - has_reference_model = dist.get_rank(dpg.pp_pg) == reference_rank + has_reference_model = dist.get_rank(parallel_context.pp_pg) == reference_rank pipeline_engine = AllForwardAllBackwardPipelineEngine() - current_pp_rank = dist.get_rank(dpg.pp_pg) + current_pp_rank = dist.get_rank(parallel_context.pp_pg) # spawn model model = DummyModel(p2p=p2p) @@ -49,12 +49,12 @@ def _test_clip_grads_with_pp(dpg: DistributedProcessGroups, norm_type: float): reference_model = DummyModel(p2p=p2p) # Set the ranks - assert len(model.mlp) == dpg.pp_pg.size() + assert len(model.mlp) == parallel_context.pp_pg.size() with init_on_device_and_dtype(device): - for pp_rank, non_linear in zip(range(dpg.pp_pg.size()), model.mlp): + for pp_rank, non_linear in zip(range(parallel_context.pp_pg.size()), model.mlp): non_linear.linear.build_and_set_rank(pp_rank=pp_rank) non_linear.activation.build_and_set_rank(pp_rank=pp_rank) - model.loss.build_and_set_rank(pp_rank=dpg.pp_pg.size() - 1) + 
model.loss.build_and_set_rank(pp_rank=parallel_context.pp_pg.size() - 1) # build reference model if has_reference_model: @@ -71,7 +71,7 @@ def _test_clip_grads_with_pp(dpg: DistributedProcessGroups, norm_type: float): # synchronize weights if has_reference_model: with torch.inference_mode(): - for pp_rank in range(dpg.pp_pg.size()): + for pp_rank in range(parallel_context.pp_pg.size()): reference_non_linear = reference_model.mlp[pp_rank].linear.pp_block if pp_rank == current_pp_rank: # We already have the weights locally @@ -90,12 +90,12 @@ def _test_clip_grads_with_pp(dpg: DistributedProcessGroups, norm_type: float): ) # Get infinite dummy data iterator - data_iterator = dummy_infinite_data_loader(pp_pg=dpg.pp_pg) # First rank receives data + data_iterator = dummy_infinite_data_loader(pp_pg=parallel_context.pp_pg) # First rank receives data n_micro_batches_per_batch = 5 batch = [next(data_iterator) for _ in range(n_micro_batches_per_batch)] pipeline_engine.train_batch_iter( - model, pg=dpg.pp_pg, batch=batch, nb_microbatches=n_micro_batches_per_batch, grad_accumulator=None + model, pg=parallel_context.pp_pg, batch=batch, nb_microbatches=n_micro_batches_per_batch, grad_accumulator=None ) # Equivalent on the reference model @@ -106,9 +106,9 @@ def _test_clip_grads_with_pp(dpg: DistributedProcessGroups, norm_type: float): loss.backward() # Check that gradient are the same as reference - pp_rank = dist.get_rank(dpg.pp_pg) + pp_rank = dist.get_rank(parallel_context.pp_pg) if has_reference_model: - for pp_rank in range(dpg.pp_pg.size()): + for pp_rank in range(parallel_context.pp_pg.size()): reference_non_linear = reference_model.mlp[pp_rank].linear.pp_block if pp_rank == current_pp_rank: # We already have the gradients locally @@ -119,12 +119,7 @@ def _test_clip_grads_with_pp(dpg: DistributedProcessGroups, norm_type: float): atol=1e-6, rtol=1e-7, ) - torch.testing.assert_close( - non_linear.bias.grad, - reference_non_linear.bias.grad, - atol=1e-6, - rtol=1e-7 - ) + torch.testing.assert_close(non_linear.bias.grad, reference_non_linear.bias.grad, atol=1e-6, rtol=1e-7) continue weight_grad, bias_grad = p2p.recv_tensors(num_tensors=2, from_rank=pp_rank) @@ -141,7 +136,9 @@ def _test_clip_grads_with_pp(dpg: DistributedProcessGroups, norm_type: float): old_bias_grad = non_linear.bias.grad.clone() # Clip grads total_norm = clip_grad_norm( - mp_pg=dpg.world_ranks_to_pg[tuple(sorted(dpg.world_rank_matrix[:, dist.get_rank(dpg.dp_pg), :].reshape(-1)))], + mp_pg=parallel_context.world_ranks_to_pg[ + tuple(sorted(parallel_context.world_rank_matrix[:, dist.get_rank(parallel_context.dp_pg), :].reshape(-1))) + ], named_parameters=model.named_parameters(), grad_accumulator=None, max_norm=1.0, @@ -159,7 +156,7 @@ def _test_clip_grads_with_pp(dpg: DistributedProcessGroups, norm_type: float): # Check that gradient are the same as reference if has_reference_model: - for pp_rank in range(dpg.pp_pg.size()): + for pp_rank in range(parallel_context.pp_pg.size()): reference_non_linear = reference_model.mlp[pp_rank].linear.pp_block if pp_rank == current_pp_rank: # We already have the gradients locally @@ -202,19 +199,19 @@ def test_clip_grads_with_tp(tp_mode: TensorParallelLinearMode, async_communicati def _test_clip_grads_with_tp( - dpg: DistributedProcessGroups, tp_mode: TensorParallelLinearMode, async_communication: bool, norm_type: float + parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool, norm_type: float ): if async_communication: 
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" in_features = 2 out_features_per_tp_rank = 3 - out_features = dpg.tp_pg.size() * out_features_per_tp_rank + out_features = parallel_context.tp_pg.size() * out_features_per_tp_rank # Sharded column_linear = TensorParallelColumnLinear( in_features=in_features, out_features=out_features, - pg=dpg.tp_pg, + pg=parallel_context.tp_pg, mode=tp_mode, device="cuda", async_communication=async_communication, @@ -228,12 +225,12 @@ def _test_clip_grads_with_tp( dist.all_gather( tensor_list=list(reference_linear.weight.split(out_features_per_tp_rank, dim=0)), tensor=column_linear.weight, - group=dpg.tp_pg, + group=parallel_context.tp_pg, ) dist.all_gather( tensor_list=list(reference_linear.bias.split(out_features_per_tp_rank, dim=0)), tensor=column_linear.bias, - group=dpg.tp_pg, + group=parallel_context.tp_pg, ) # Generate random input @@ -243,18 +240,18 @@ def _test_clip_grads_with_tp( batch_size = 5 random_input = torch.randn(batch_size, in_features, device="cuda") # synchronize random_input across tp - dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=dpg.tp_pg) + dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg) sharded_random_input = random_input elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: sharded_batch_size = 5 sharded_random_input = torch.randn(sharded_batch_size, in_features, device="cuda") random_input = torch.empty( - sharded_batch_size * dpg.tp_pg.size(), + sharded_batch_size * parallel_context.tp_pg.size(), *(sharded_random_input.shape[1:]), device=sharded_random_input.device, dtype=sharded_random_input.dtype, ) - dist.all_gather_into_tensor(random_input, sharded_random_input, group=dpg.tp_pg) + dist.all_gather_into_tensor(random_input, sharded_random_input, group=parallel_context.tp_pg) else: ValueError(f"Unsupported mode: {tp_mode}") @@ -266,8 +263,8 @@ def _test_clip_grads_with_tp( sharded_output, reference_output[ :, - dist.get_rank(dpg.tp_pg) - * out_features_per_tp_rank : (dist.get_rank(dpg.tp_pg) + 1) + dist.get_rank(parallel_context.tp_pg) + * out_features_per_tp_rank : (dist.get_rank(parallel_context.tp_pg) + 1) * out_features_per_tp_rank, ], atol=1e-6, @@ -280,8 +277,8 @@ def _test_clip_grads_with_tp( torch.testing.assert_close( column_linear.weight.grad, reference_linear.weight.grad[ - dist.get_rank(dpg.tp_pg) - * out_features_per_tp_rank : (dist.get_rank(dpg.tp_pg) + 1) + dist.get_rank(parallel_context.tp_pg) + * out_features_per_tp_rank : (dist.get_rank(parallel_context.tp_pg) + 1) * out_features_per_tp_rank ], atol=1e-6, @@ -290,8 +287,8 @@ def _test_clip_grads_with_tp( torch.testing.assert_close( column_linear.bias.grad, reference_linear.bias.grad[ - dist.get_rank(dpg.tp_pg) - * out_features_per_tp_rank : (dist.get_rank(dpg.tp_pg) + 1) + dist.get_rank(parallel_context.tp_pg) + * out_features_per_tp_rank : (dist.get_rank(parallel_context.tp_pg) + 1) * out_features_per_tp_rank ], atol=1e-6, @@ -301,7 +298,9 @@ def _test_clip_grads_with_tp( old_grad = column_linear.weight.grad.clone() # Clip grads total_norm = clip_grad_norm( - mp_pg=dpg.world_ranks_to_pg[tuple(sorted(dpg.world_rank_matrix[:, dist.get_rank(dpg.dp_pg), :].reshape(-1)))], + mp_pg=parallel_context.world_ranks_to_pg[ + tuple(sorted(parallel_context.world_rank_matrix[:, dist.get_rank(parallel_context.dp_pg), :].reshape(-1))) + ], named_parameters=column_linear.named_parameters(), grad_accumulator=None, max_norm=1.0, @@ -316,16 +315,16 @@ def _test_clip_grads_with_tp( torch.testing.assert_close( 
column_linear.weight.grad, reference_linear.weight.grad[ - dist.get_rank(dpg.tp_pg) - * out_features_per_tp_rank : (dist.get_rank(dpg.tp_pg) + 1) + dist.get_rank(parallel_context.tp_pg) + * out_features_per_tp_rank : (dist.get_rank(parallel_context.tp_pg) + 1) * out_features_per_tp_rank ], ) torch.testing.assert_close( column_linear.bias.grad, reference_linear.bias.grad[ - dist.get_rank(dpg.tp_pg) - * out_features_per_tp_rank : (dist.get_rank(dpg.tp_pg) + 1) + dist.get_rank(parallel_context.tp_pg) + * out_features_per_tp_rank : (dist.get_rank(parallel_context.tp_pg) + 1) * out_features_per_tp_rank ], ) @@ -338,8 +337,8 @@ def test_clip_grads_tied_weights(norm_type: float): init_distributed(tp=1, dp=1, pp=2)(_test_clip_grads_tied_weights)(norm_type=norm_type) -def _test_clip_grads_tied_weights(dpg: DistributedProcessGroups, norm_type: float): - if dist.get_rank(dpg.pp_pg) == 0: +def _test_clip_grads_tied_weights(parallel_context: ParallelContext, norm_type: float): + if dist.get_rank(parallel_context.pp_pg) == 0: model = nn.ModuleDict( { "dense0": nn.Linear(10, 10, device="cuda"), @@ -356,17 +355,20 @@ def _test_clip_grads_tied_weights(dpg: DistributedProcessGroups, norm_type: floa tie_parameters( root_module=model, ties=[("dense0.weight", (0,)), ("dense1.weight", (1,))], - dpg=dpg, + dpg=parallel_context, reduce_op=dist.ReduceOp.SUM, ) tie_parameters( - root_module=model, ties=[("dense0.bias", (0,)), ("dense1.bias", (1,))], dpg=dpg, reduce_op=dist.ReduceOp.SUM + root_module=model, + ties=[("dense0.bias", (0,)), ("dense1.bias", (1,))], + dpg=parallel_context, + reduce_op=dist.ReduceOp.SUM, ) - group = dpg.world_ranks_to_pg[(0, 1)] + group = parallel_context.world_ranks_to_pg[(0, 1)] # Check that model weights are not in fact synchronized - if dist.get_rank(dpg.pp_pg) == 0: + if dist.get_rank(parallel_context.pp_pg) == 0: weight = model.dense0.weight bias = model.dense0.bias else: @@ -380,7 +382,7 @@ def _test_clip_grads_tied_weights(dpg: DistributedProcessGroups, norm_type: floa assert bias.is_tied # Sync tied weights: basic assumption - initial_sync(model=model, dpg=dpg) + initial_sync(model=model, dpg=parallel_context) # Check that weights are now synced assert_tensor_synced_across_pg(weight, group) @@ -388,14 +390,14 @@ def _test_clip_grads_tied_weights(dpg: DistributedProcessGroups, norm_type: floa # Compute gradient input_ = torch.randn(13, 10, device="cuda") - if dist.get_rank(dpg.pp_pg) == 0: + if dist.get_rank(parallel_context.pp_pg) == 0: out = model.dense0(input_) else: out = model.dense1(input_) out.sum().backward() # sync gradients - sync_tied_weights_gradients(model, dpg=dpg, grad_accumulator=None) + sync_tied_weights_gradients(model, dpg=parallel_context, grad_accumulator=None) # We check that we both gradients are synchronized assert_tensor_synced_across_pg(weight.grad, group) @@ -410,7 +412,9 @@ def _test_clip_grads_tied_weights(dpg: DistributedProcessGroups, norm_type: floa old_grad = weight.grad.clone() # Clip grads total_norm = clip_grad_norm( - mp_pg=dpg.world_ranks_to_pg[tuple(sorted(dpg.world_rank_matrix[:, dist.get_rank(dpg.dp_pg), :].reshape(-1)))], + mp_pg=parallel_context.world_ranks_to_pg[ + tuple(sorted(parallel_context.world_rank_matrix[:, dist.get_rank(parallel_context.dp_pg), :].reshape(-1))) + ], named_parameters=model.named_parameters(), grad_accumulator=None, max_norm=1.0, @@ -435,13 +439,15 @@ def test_clip_grads_fp32_accumulator(norm_type: float, half_precision: torch.dty ) -def _test_clip_grads_fp32_accumulator(dpg: DistributedProcessGroups, norm_type: 
float, half_precision: torch.dtype): +def _test_clip_grads_fp32_accumulator( + parallel_context: ParallelContext, norm_type: float, half_precision: torch.dtype +): device = torch.device("cuda") - p2p = P2P(dpg.pp_pg, device=device) + p2p = P2P(parallel_context.pp_pg, device=device) reference_rank = 0 - has_reference_model = dist.get_rank(dpg.pp_pg) == reference_rank + has_reference_model = dist.get_rank(parallel_context.pp_pg) == reference_rank pipeline_engine = AllForwardAllBackwardPipelineEngine() - current_pp_rank = dist.get_rank(dpg.pp_pg) + current_pp_rank = dist.get_rank(parallel_context.pp_pg) # spawn model model = DummyModel(p2p=p2p) @@ -449,12 +455,12 @@ def _test_clip_grads_fp32_accumulator(dpg: DistributedProcessGroups, norm_type: reference_model = DummyModel(p2p=p2p).to(torch.float) # Set the ranks - assert len(model.mlp) == dpg.pp_pg.size() + assert len(model.mlp) == parallel_context.pp_pg.size() with init_on_device_and_dtype(device): - for pp_rank, non_linear in zip(range(dpg.pp_pg.size()), model.mlp): + for pp_rank, non_linear in zip(range(parallel_context.pp_pg.size()), model.mlp): non_linear.linear.build_and_set_rank(pp_rank=pp_rank) non_linear.activation.build_and_set_rank(pp_rank=pp_rank) - model.loss.build_and_set_rank(pp_rank=dpg.pp_pg.size() - 1) + model.loss.build_and_set_rank(pp_rank=parallel_context.pp_pg.size() - 1) if has_reference_model: for non_linear in reference_model.mlp: @@ -473,7 +479,7 @@ def _test_clip_grads_fp32_accumulator(dpg: DistributedProcessGroups, norm_type: # synchronize weights if has_reference_model: with torch.inference_mode(): - for pp_rank in range(dpg.pp_pg.size()): + for pp_rank in range(parallel_context.pp_pg.size()): reference_non_linear = reference_model.mlp[pp_rank].linear.pp_block if pp_rank == current_pp_rank: # We already have the weights locally @@ -499,18 +505,24 @@ def _test_clip_grads_fp32_accumulator(dpg: DistributedProcessGroups, norm_type: # Compute backward # Get infinite dummy data iterator - data_iterator = dummy_infinite_data_loader(pp_pg=dpg.pp_pg, dtype=half_precision) # First rank receives data + data_iterator = dummy_infinite_data_loader( + pp_pg=parallel_context.pp_pg, dtype=half_precision + ) # First rank receives data n_micro_batches_per_batch = 5 batch = [next(data_iterator) for _ in range(n_micro_batches_per_batch)] pipeline_engine.train_batch_iter( - model, pg=dpg.pp_pg, batch=batch, nb_microbatches=n_micro_batches_per_batch, grad_accumulator=grad_accumulator + model, + pg=parallel_context.pp_pg, + batch=batch, + nb_microbatches=n_micro_batches_per_batch, + grad_accumulator=grad_accumulator, ) # We're going to copy the model gradients to the reference model gradient # The reason why we do this, instead of computing backward using autograd is because of numerical precisions if has_reference_model: - for pp_rank in range(dpg.pp_pg.size()): + for pp_rank in range(parallel_context.pp_pg.size()): reference_non_linear = reference_model.mlp[pp_rank].linear.pp_block prefix_name = f"mlp.{pp_rank}.linear.pp_block" if pp_rank == current_pp_rank: @@ -537,7 +549,9 @@ def _test_clip_grads_fp32_accumulator(dpg: DistributedProcessGroups, norm_type: # Clip grads total_norm = clip_grad_norm( - mp_pg=dpg.world_ranks_to_pg[tuple(sorted(dpg.world_rank_matrix[:, dist.get_rank(dpg.dp_pg), :].reshape(-1)))], + mp_pg=parallel_context.world_ranks_to_pg[ + tuple(sorted(parallel_context.world_rank_matrix[:, dist.get_rank(parallel_context.dp_pg), :].reshape(-1))) + ], named_parameters=model.named_parameters(), 
grad_accumulator=grad_accumulator, max_norm=1.0, @@ -562,7 +576,7 @@ def _test_clip_grads_fp32_accumulator(dpg: DistributedProcessGroups, norm_type: rtol=1e-7, msg=lambda msg: f"Expected {total_norm} to match {ref_total_norm}.\n{msg}", ) - for pp_rank in range(dpg.pp_pg.size()): + for pp_rank in range(parallel_context.pp_pg.size()): reference_non_linear = reference_model.mlp[pp_rank].linear.pp_block prefix_name = f"mlp.{pp_rank}.linear.pp_block" if pp_rank == current_pp_rank: diff --git a/tests/test_data_parallel.py b/tests/test_data_parallel.py index 2fdcef22..4072a22b 100644 --- a/tests/test_data_parallel.py +++ b/tests/test_data_parallel.py @@ -7,8 +7,8 @@ from nanotron.core import distributed as dist from nanotron.core.parallel.data_parallelism.utils import ddp_trigger_sync_in_bwd from nanotron.core.parallel.parameters import NanotronParameter -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.utils import assert_tensor_synced_across_pg +from nanotron.distributed import ParallelContext from torch import nn from torch.distributed import GradBucket @@ -19,15 +19,15 @@ def test_ddp_with_afab(accumulation_steps): init_distributed(tp=1, dp=2, pp=1)(_test_ddp_with_afab)(accumulation_steps=accumulation_steps) -def _test_ddp_with_afab(dpg: DistributedProcessGroups, accumulation_steps: int): - dist.get_rank(dpg.dp_pg) +def _test_ddp_with_afab(parallel_context: ParallelContext, accumulation_steps: int): + dist.get_rank(parallel_context.dp_pg) half_precision = torch.float16 def allreduce_hook(process_group: dist.ProcessGroup, bucket: GradBucket): # DDP groups grads in GradBuckets. This hook is called throughout the bwd pass, once each bucket is ready to overlap communication with computation. # See https://pytorch.org/docs/stable/ddp_comm_hooks.html#what-does-a-communication-hook-operate-on for more details. 
half_flat_bucket_buffer = bucket.buffer() - group_to_use = process_group if process_group is not None else dpg.dp_pg + group_to_use = process_group if process_group is not None else parallel_context.dp_pg return ( dist.all_reduce(half_flat_bucket_buffer, group=group_to_use, async_op=True, op=dist.ReduceOp.AVG) @@ -42,7 +42,7 @@ def allreduce_hook(process_group: dist.ProcessGroup, bucket: GradBucket): model_ddp_hook = torch.nn.parallel.DistributedDataParallel( model_hook, - process_group=dpg.dp_pg, + process_group=parallel_context.dp_pg, ) # Register DDP hook @@ -71,7 +71,7 @@ def allreduce_hook(process_group: dist.ProcessGroup, bucket: GradBucket): # Check that the gradients are synchronized across DP if i == accumulation_steps - 1: - assert_tensor_synced_across_pg(grad_hook, dpg.dp_pg) + assert_tensor_synced_across_pg(grad_hook, parallel_context.dp_pg) else: - with assert_fail_except_rank_with(AssertionError, rank_exception=0, pg=dpg.dp_pg): - assert_tensor_synced_across_pg(grad_hook, dpg.dp_pg) + with assert_fail_except_rank_with(AssertionError, rank_exception=0, pg=parallel_context.dp_pg): + assert_tensor_synced_across_pg(grad_hook, parallel_context.dp_pg) diff --git a/tests/test_p2p.py b/tests/test_p2p.py index 34ec20c7..2ac66eda 100644 --- a/tests/test_p2p.py +++ b/tests/test_p2p.py @@ -6,7 +6,7 @@ from helpers.utils import available_gpus, init_distributed from nanotron.core import distributed as dist from nanotron.core.parallel.pipeline_parallelism.p2p import P2P -from nanotron.core.process_groups import DistributedProcessGroups +from nanotron.distributed import ParallelContext @pytest.mark.skipif(available_gpus() < 2, reason="Testing test_ddp_with_afab requires at least 2 gpus") @@ -16,8 +16,8 @@ def test_check_send_recv_tensor(send_contiguous: bool, full: bool): init_distributed(tp=1, dp=1, pp=2)(_test_check_send_recv_tensor)(send_contiguous=send_contiguous, full=full) -def _test_check_send_recv_tensor(dpg: DistributedProcessGroups, send_contiguous: bool, full: bool): - p2p = P2P(pg=dpg.pp_pg, device=torch.device("cuda")) +def _test_check_send_recv_tensor(parallel_context: ParallelContext, send_contiguous: bool, full: bool): + p2p = P2P(pg=parallel_context.pp_pg, device=torch.device("cuda")) if dist.get_rank(p2p.pg) == 0: tensor_to_send = torch.randn(3, 5, dtype=torch.float, device=torch.device("cuda")) if send_contiguous is True: diff --git a/tests/test_parameters_accumulate_gradient_in_fp32.py b/tests/test_parameters_accumulate_gradient_in_fp32.py index 96638fda..0f4333d5 100644 --- a/tests/test_parameters_accumulate_gradient_in_fp32.py +++ b/tests/test_parameters_accumulate_gradient_in_fp32.py @@ -26,8 +26,8 @@ sync_tied_weights_gradients, tie_parameters, ) -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.utils import ContextManagers, assert_tensor_synced_across_pg, init_on_device_and_dtype +from nanotron.distributed import ParallelContext from torch import nn @@ -148,7 +148,7 @@ def test_ddp_with_grad_accum_in_fp32(half_precision: torch.dtype, accumulation_s def _test_ddp_with_grad_accum_in_fp32( - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, half_precision: torch.dtype, accumulation_steps: int, train_iterations: int, @@ -176,12 +176,12 @@ def _test_ddp_with_grad_accum_in_fp32( # Needed in order to obtain smaller gradient buckets when using `DistributedDataParallel` model_ddp = torch.nn.parallel.DistributedDataParallel( model, - process_group=dpg.dp_pg, + process_group=parallel_context.dp_pg, ) # we won't actually use 
DDP anywhere, it's just to have same module names model_ddp_accum_ref = {} model_ddp_fp32_accum = torch.nn.parallel.DistributedDataParallel( model_hook, - process_group=dpg.dp_pg, + process_group=parallel_context.dp_pg, ) # Add gradient accumulator @@ -189,7 +189,7 @@ def _test_ddp_with_grad_accum_in_fp32( # Register DDP hook state = FP32GradBucketManager( - dp_pg=dpg.dp_pg, + dp_pg=parallel_context.dp_pg, accumulator=accumulator, param_id_to_name={id(param): name for name, param in model_ddp_fp32_accum.named_parameters()}, ) @@ -222,19 +222,19 @@ def _test_ddp_with_grad_accum_in_fp32( model_ddp_accum_ref[name] = ( grad.float() if accum_step == 0 else model_ddp_accum_ref[name] + grad.float() ) - + dist.barrier() torch.testing.assert_close(model_ddp_accum_ref[name], fp32_grad_bucket, atol=1e-6, rtol=1e-7) - + dist.barrier() # Check that we correctly copied grads from buckets to params (`copy_buckets_to_grads`) torch.testing.assert_close(fp32_grad_bucket, grad_fp32_accum, atol=1e-6, rtol=1e-7) # Check that the gradients are not synchronized across DP - with assert_fail_except_rank_with(AssertionError, rank_exception=0, pg=dpg.dp_pg): - assert_tensor_synced_across_pg(grad, dpg.dp_pg) - with assert_fail_except_rank_with(AssertionError, rank_exception=0, pg=dpg.dp_pg): - assert_tensor_synced_across_pg(fp32_grad_bucket, dpg.dp_pg) + with assert_fail_except_rank_with(AssertionError, rank_exception=0, pg=parallel_context.dp_pg): + assert_tensor_synced_across_pg(grad, parallel_context.dp_pg) + with assert_fail_except_rank_with(AssertionError, rank_exception=0, pg=parallel_context.dp_pg): + assert_tensor_synced_across_pg(fp32_grad_bucket, parallel_context.dp_pg) # We zero out half grads for `model_ddp` because we're accumulating grads manually in `model_ddp_accum_ref` model_ddp.zero_grad() @@ -249,7 +249,7 @@ def _test_ddp_with_grad_accum_in_fp32( model_ddp_accum_ref[name] = ( model_ddp_accum_ref[name] + grad.float() if name in model_ddp_accum_ref else grad.float() ) - dist.all_reduce(model_ddp_accum_ref[name], group=dpg.dp_pg, op=dist.ReduceOp.AVG) + dist.all_reduce(model_ddp_accum_ref[name], group=parallel_context.dp_pg, op=dist.ReduceOp.AVG) loss_fp32_accum = model_ddp_fp32_accum(input).sum() accumulator.backward(loss_fp32_accum) @@ -271,8 +271,8 @@ def _test_ddp_with_grad_accum_in_fp32( assert grad_fp32_accum.data_ptr() == fp32_grad_bucket.data_ptr() # Check that the gradients are synchronized across DP - assert_tensor_synced_across_pg(grad, dpg.dp_pg) - assert_tensor_synced_across_pg(grad_fp32_accum, dpg.dp_pg) + assert_tensor_synced_across_pg(grad, parallel_context.dp_pg) + assert_tensor_synced_across_pg(grad_fp32_accum, parallel_context.dp_pg) # Zero out gradients (Usually it's the optimizer that does this) model_ddp.zero_grad() @@ -311,12 +311,12 @@ def test_tied_weights_sync_with_grad_accum_in_fp32(pipeline_engine: PipelineEngi def _test_tied_weights_sync_with_grad_accum_in_fp32( - dpg: DistributedProcessGroups, pipeline_engine: PipelineEngine, reduce_scatter: bool + parallel_context: ParallelContext, pipeline_engine: PipelineEngine, reduce_scatter: bool ): # We init two replicas of 2 denses. Each dense is on a device. 
dtype = torch.float16 device = torch.device("cuda") - p2p = P2P(pg=dpg.pp_pg, device=device) + p2p = P2P(pg=parallel_context.pp_pg, device=device) model = DummyModel(p2p=p2p) reference_model = DummyModel(p2p=p2p) @@ -325,11 +325,11 @@ def _test_tied_weights_sync_with_grad_accum_in_fp32( for mdl in [model, reference_model]: # Set the ranks with init_on_device_and_dtype(device, dtype): - assert dpg.pp_pg.size() == len(mdl.mlp) - for pp_rank, non_linear in zip(range(dpg.pp_pg.size()), mdl.mlp): + assert parallel_context.pp_pg.size() == len(mdl.mlp) + for pp_rank, non_linear in zip(range(parallel_context.pp_pg.size()), mdl.mlp): non_linear.linear.build_and_set_rank(pp_rank=pp_rank) non_linear.activation.build_and_set_rank(pp_rank=pp_rank) - mdl.loss.build_and_set_rank(pp_rank=dpg.pp_pg.size() - 1) + mdl.loss.build_and_set_rank(pp_rank=parallel_context.pp_pg.size() - 1) # Tie all dense weights across PP tie_parameters( @@ -338,14 +338,18 @@ def _test_tied_weights_sync_with_grad_accum_in_fp32( ( target, ( - dpg.world_rank_matrix[ - get_pp_rank_of(target, module=mdl), dist.get_rank(dpg.dp_pg), dist.get_rank(dpg.tp_pg) + parallel_context.world_rank_matrix[ + get_pp_rank_of(target, module=mdl), + dist.get_rank(parallel_context.dp_pg), + dist.get_rank(parallel_context.tp_pg), ], ), ) - for target in [f"mlp.{pp_rank}.linear.pp_block.weight" for pp_rank in range(dpg.pp_pg.size())] + for target in [ + f"mlp.{pp_rank}.linear.pp_block.weight" for pp_rank in range(parallel_context.pp_pg.size()) + ] ], - dpg=dpg, + dpg=parallel_context, reduce_op=dist.ReduceOp.SUM, ) @@ -354,7 +358,7 @@ def _test_tied_weights_sync_with_grad_accum_in_fp32( module.bias = NanotronParameter(module.bias) # Sync DP and tied weights: basic assumption - initial_sync(model=mdl, dpg=dpg) + initial_sync(model=mdl, dpg=parallel_context) # Sync params between `model` and `reference_model` with torch.no_grad(): @@ -362,7 +366,7 @@ def _test_tied_weights_sync_with_grad_accum_in_fp32( param.copy_(reference_model.get_parameter(name)) # DDP - model_ddp = torch.nn.parallel.DistributedDataParallel(model, process_group=dpg.dp_pg) + model_ddp = torch.nn.parallel.DistributedDataParallel(model, process_group=parallel_context.dp_pg) module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()} reference_module_id_to_prefix = { id(module): f"{module_name}." 
for module_name, module in reference_model.named_modules() @@ -384,7 +388,7 @@ def _test_tied_weights_sync_with_grad_accum_in_fp32( # Optimizer: We don't actually run the optimizer, we just use it to build the gradient accumulator optimizer = ZeroDistributedOptimizer( - dp_pg=dpg.dp_pg, + dp_pg=parallel_context.dp_pg, named_params_or_groups=named_parameters, optimizer_builder=lambda named_param_groups_1: OptimizerFromGradientAccumulator( gradient_accumulator_builder=lambda named_params: FP32GradientAccumulator( @@ -411,12 +415,12 @@ def _test_tied_weights_sync_with_grad_accum_in_fp32( # We use `model_ddp.module` in order ta have the parameter names without the `module.` prefix accumulator = optimizer.optimizer.gradient_accumulator accumulator.assign_param_offsets( - dp_rank=dist.get_rank(dpg.dp_pg), + dp_rank=dist.get_rank(parallel_context.dp_pg), param_name_to_offsets=optimizer.param_name_to_dp_rank_offsets, ) model_ddp.register_comm_hook( state=FP32GradBucketManager( - dp_pg=dpg.dp_pg, + dp_pg=parallel_context.dp_pg, accumulator=accumulator, param_id_to_name=param_id_to_name, ), @@ -424,7 +428,7 @@ def _test_tied_weights_sync_with_grad_accum_in_fp32( ) # Get infinite dummy data iterator - data_iterator = dummy_infinite_data_loader(pp_pg=dpg.pp_pg, dtype=dtype) # First rank receives data + data_iterator = dummy_infinite_data_loader(pp_pg=parallel_context.pp_pg, dtype=dtype) # First rank receives data n_micro_batches_per_batch = 2 batch = [next(data_iterator) for _ in range(n_micro_batches_per_batch)] @@ -432,7 +436,7 @@ def _test_tied_weights_sync_with_grad_accum_in_fp32( ## Reference model iteration step def forward_backward_reference(mdl, micro_batch): pipeline_engine.train_batch_iter( - mdl, pg=dpg.pp_pg, batch=[micro_batch], nb_microbatches=1, grad_accumulator=None + mdl, pg=parallel_context.pp_pg, batch=[micro_batch], nb_microbatches=1, grad_accumulator=None ) for accum_step in range(n_micro_batches_per_batch - 1): @@ -464,11 +468,15 @@ def forward_backward_reference(mdl, micro_batch): reference_model_accum_ref[name] = ( reference_model_accum_ref[name] + grad.float() if name in reference_model_accum_ref else grad.float() ) - dist.all_reduce(reference_model_accum_ref[name], group=dpg.dp_pg, op=dist.ReduceOp.AVG) + dist.all_reduce(reference_model_accum_ref[name], group=parallel_context.dp_pg, op=dist.ReduceOp.AVG) ## Model iteration step pipeline_engine.train_batch_iter( - model_ddp, pg=dpg.pp_pg, batch=batch, nb_microbatches=n_micro_batches_per_batch, grad_accumulator=accumulator + model_ddp, + pg=parallel_context.pp_pg, + batch=batch, + nb_microbatches=n_micro_batches_per_batch, + grad_accumulator=accumulator, ) for name, param in model_ddp.module.named_parameters(): if param.is_tied: @@ -484,11 +492,11 @@ def forward_backward_reference(mdl, micro_batch): if not reduce_scatter: # Check that the gradients are synchronized across DP - assert_tensor_synced_across_pg(fp32_grad, dpg.dp_pg) + assert_tensor_synced_across_pg(fp32_grad, parallel_context.dp_pg) fp32_grad_ref = reference_model_accum_ref[name] dist.barrier() - + if reduce_scatter: slice_ = slice(*accumulator.param_name_to_offsets[name]) # Check that gradients are correct @@ -511,7 +519,7 @@ def forward_backward_reference(mdl, micro_batch): if not (isinstance(param, NanotronParameter) and param.is_tied): continue - group = dpg.world_ranks_to_pg[group_ranks] + group = parallel_context.world_ranks_to_pg[group_ranks] fp32_grad = accumulator.get_grad_buffer(name=name) with assert_fail_except_rank_with(AssertionError, 
rank_exception=0, pg=group): @@ -524,7 +532,7 @@ def forward_backward_reference(mdl, micro_batch): # - Translate tied ranks along DP axis to find the DP rank that has the tied weights # - accumulator keeps grads for all DPs, so we can just sync the grads with timeout_after(): - sync_tied_weights_gradients(module=model_ddp.module, dpg=dpg, grad_accumulator=accumulator) + sync_tied_weights_gradients(module=model_ddp.module, dpg=parallel_context, grad_accumulator=accumulator) tied_infos_dict = { ( @@ -552,7 +560,7 @@ def forward_backward_reference(mdl, micro_batch): dp_slice_fp_32_grad_buffer.data_ptr() == fp32_grad.data_ptr() ), "dp_slice_fp_32_grad_buffer and fp32_grad should point to the same memory" - group = dpg.world_ranks_to_pg[group_ranks] + group = parallel_context.world_ranks_to_pg[group_ranks] # Check that fp32 grads for tied weights are synced (Used in optimizer step) # Since we use `reduce_scatter = False` the entire gradient buffer is all reduced, causing it to be synced diff --git a/tests/test_pipeline_parallel.py b/tests/test_pipeline_parallel.py index 4f182082..115f8710 100644 --- a/tests/test_pipeline_parallel.py +++ b/tests/test_pipeline_parallel.py @@ -13,8 +13,8 @@ ) from nanotron.core.parallel.pipeline_parallelism.p2p import P2P from nanotron.core.parallel.pipeline_parallelism.tensor_pointer import TensorPointer -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.utils import init_on_device_and_dtype +from nanotron.distributed import ParallelContext from torch import nn from torch.nn import functional as F @@ -24,26 +24,26 @@ def test_build_and_set_rank(): init_distributed(tp=1, dp=1, pp=2)(_test_build_and_set_rank)() -def _test_build_and_set_rank(dpg: DistributedProcessGroups): +def _test_build_and_set_rank(parallel_context: ParallelContext): device = torch.device("cuda") - p2p = P2P(pg=dpg.pp_pg, device=device) + p2p = P2P(pg=parallel_context.pp_pg, device=device) model = DummyModel(p2p=p2p) # Set the ranks - assert len(model.mlp) == dpg.pp_pg.size() + assert len(model.mlp) == parallel_context.pp_pg.size() with init_on_device_and_dtype(device): - for pp_rank, non_linear in zip(range(dpg.pp_pg.size()), model.mlp): + for pp_rank, non_linear in zip(range(parallel_context.pp_pg.size()), model.mlp): non_linear.linear.build_and_set_rank(pp_rank=pp_rank) non_linear.activation.build_and_set_rank(pp_rank=pp_rank) - model.loss.build_and_set_rank(pp_rank=dpg.pp_pg.size() - 1) + model.loss.build_and_set_rank(pp_rank=parallel_context.pp_pg.size() - 1) # Check that the ranks are set correctly - current_pp_rank = dist.get_rank(dpg.pp_pg) + current_pp_rank = dist.get_rank(parallel_context.pp_pg) assert model.mlp[current_pp_rank].linear.rank == current_pp_rank assert model.mlp[current_pp_rank].activation.rank == current_pp_rank # Check that blocks were built on the correct ranks - for pp_rank, non_linear in zip(range(dpg.pp_pg.size()), model.mlp): + for pp_rank, non_linear in zip(range(parallel_context.pp_pg.size()), model.mlp): if pp_rank == current_pp_rank: assert hasattr(non_linear.linear, "pp_block") assert hasattr(non_linear.activation, "pp_block") @@ -71,12 +71,12 @@ def test_pipeline_engine(pipeline_engine: PipelineEngine, pp: int): init_distributed(tp=1, dp=1, pp=pp)(_test_pipeline_engine)(pipeline_engine=pipeline_engine) -def _test_pipeline_engine(dpg: DistributedProcessGroups, pipeline_engine: PipelineEngine): +def _test_pipeline_engine(parallel_context: ParallelContext, pipeline_engine: PipelineEngine): device = torch.device("cuda") - p2p = 
P2P(dpg.pp_pg, device=device) + p2p = P2P(parallel_context.pp_pg, device=device) reference_rank = 0 - has_reference_model = dist.get_rank(dpg.pp_pg) == reference_rank - current_pp_rank = dist.get_rank(dpg.pp_pg) + has_reference_model = dist.get_rank(parallel_context.pp_pg) == reference_rank + current_pp_rank = dist.get_rank(parallel_context.pp_pg) # spawn model model = DummyModel(p2p=p2p) @@ -84,12 +84,12 @@ def _test_pipeline_engine(dpg: DistributedProcessGroups, pipeline_engine: Pipeli reference_model = DummyModel(p2p=p2p) # Set the ranks - assert len(model.mlp) == dpg.pp_pg.size() + assert len(model.mlp) == parallel_context.pp_pg.size() with init_on_device_and_dtype(device): - for pp_rank, non_linear in zip(range(dpg.pp_pg.size()), model.mlp): + for pp_rank, non_linear in zip(range(parallel_context.pp_pg.size()), model.mlp): non_linear.linear.build_and_set_rank(pp_rank=pp_rank) non_linear.activation.build_and_set_rank(pp_rank=pp_rank) - model.loss.build_and_set_rank(pp_rank=dpg.pp_pg.size() - 1) + model.loss.build_and_set_rank(pp_rank=parallel_context.pp_pg.size() - 1) # build reference model if has_reference_model: @@ -101,7 +101,7 @@ def _test_pipeline_engine(dpg: DistributedProcessGroups, pipeline_engine: Pipeli # synchronize weights if has_reference_model: with torch.inference_mode(): - for pp_rank in range(dpg.pp_pg.size()): + for pp_rank in range(parallel_context.pp_pg.size()): non_linear = model.mlp[pp_rank] reference_non_linear = reference_model.mlp[pp_rank] if pp_rank == current_pp_rank: @@ -120,14 +120,14 @@ def _test_pipeline_engine(dpg: DistributedProcessGroups, pipeline_engine: Pipeli ) # Get infinite dummy data iterator - data_iterator = dummy_infinite_data_loader(pp_pg=dpg.pp_pg) # First rank receives data + data_iterator = dummy_infinite_data_loader(pp_pg=parallel_context.pp_pg) # First rank receives data # Have at least as many microbatches as PP size. 
- n_micro_batches_per_batch = dpg.pp_pg.size() + 5 + n_micro_batches_per_batch = parallel_context.pp_pg.size() + 5 batch = [next(data_iterator) for _ in range(n_micro_batches_per_batch)] losses = pipeline_engine.train_batch_iter( - model, pg=dpg.pp_pg, batch=batch, nb_microbatches=n_micro_batches_per_batch, grad_accumulator=None + model, pg=parallel_context.pp_pg, batch=batch, nb_microbatches=n_micro_batches_per_batch, grad_accumulator=None ) # Equivalent on the reference model @@ -167,7 +167,7 @@ def _test_pipeline_engine(dpg: DistributedProcessGroups, pipeline_engine: Pipeli # Check that gradient are the same as reference if has_reference_model: - for pp_rank in range(dpg.pp_pg.size()): + for pp_rank in range(parallel_context.pp_pg.size()): non_linear = model.mlp[pp_rank] reference_non_linear = reference_model.mlp[pp_rank] if pp_rank == current_pp_rank: @@ -187,7 +187,9 @@ def _test_pipeline_engine(dpg: DistributedProcessGroups, pipeline_engine: Pipeli continue weight_grad, bias_grad = p2p.recv_tensors(num_tensors=2, from_rank=pp_rank) - torch.testing.assert_close(weight_grad, reference_non_linear.linear.pp_block.weight.grad, atol=1e-6, rtol=1e-7) + torch.testing.assert_close( + weight_grad, reference_non_linear.linear.pp_block.weight.grad, atol=1e-6, rtol=1e-7 + ) torch.testing.assert_close(bias_grad, reference_non_linear.linear.pp_block.bias.grad, atol=1e-6, rtol=1e-7) else: p2p.send_tensors( @@ -215,7 +217,7 @@ def test_pipeline_engine_with_tensor_that_does_not_require_grad(pipeline_engine: def _test_pipeline_engine_with_tensor_that_does_not_require_grad( - dpg: DistributedProcessGroups, pipeline_engine: PipelineEngine + parallel_context: ParallelContext, pipeline_engine: PipelineEngine ): def activation(x: torch.Tensor, y: torch.Tensor): return {"output": F.sigmoid(x) * y, "y": y} @@ -285,9 +287,9 @@ def forward( return differentiable_tensor device = torch.device("cuda") - p2p = P2P(dpg.pp_pg, device=device) + p2p = P2P(parallel_context.pp_pg, device=device) reference_rank = 0 - current_pp_rank = dist.get_rank(dpg.pp_pg) + current_pp_rank = dist.get_rank(parallel_context.pp_pg) has_reference_model = current_pp_rank == reference_rank # spawn model @@ -296,15 +298,17 @@ def forward( reference_model = DummyModelPassingNonDifferentiableTensor(p2p=p2p) # Set the ranks - assert len(model.mlp) == dpg.pp_pg.size() + 1 + assert len(model.mlp) == parallel_context.pp_pg.size() + 1 # An additional mlp is in the end - mlp_index_pp_rank = [(i, i) for i in range(dpg.pp_pg.size())] + [(dpg.pp_pg.size(), dpg.pp_pg.size() - 1)] + mlp_index_pp_rank = [(i, i) for i in range(parallel_context.pp_pg.size())] + [ + (parallel_context.pp_pg.size(), parallel_context.pp_pg.size() - 1) + ] with init_on_device_and_dtype(device): for (mlp_index, pp_rank), non_linear in zip(mlp_index_pp_rank, model.mlp): non_linear.linear.build_and_set_rank(pp_rank=pp_rank) non_linear.activation.build_and_set_rank(pp_rank=pp_rank) - model.loss.build_and_set_rank(pp_rank=dpg.pp_pg.size() - 1) + model.loss.build_and_set_rank(pp_rank=parallel_context.pp_pg.size() - 1) # build reference model if has_reference_model: @@ -353,15 +357,15 @@ def dummy_infinite_data_loader_with_non_differentiable_tensor( } data_iterator = dummy_infinite_data_loader_with_non_differentiable_tensor( - pp_pg=dpg.pp_pg + pp_pg=parallel_context.pp_pg ) # First rank receives data # Have at least as many microbatches as PP size. 
- n_micro_batches_per_batch = dpg.pp_pg.size() + 5 + n_micro_batches_per_batch = parallel_context.pp_pg.size() + 5 batch = [next(data_iterator) for _ in range(n_micro_batches_per_batch)] losses = pipeline_engine.train_batch_iter( - model, pg=dpg.pp_pg, batch=batch, nb_microbatches=n_micro_batches_per_batch, grad_accumulator=None + model, pg=parallel_context.pp_pg, batch=batch, nb_microbatches=n_micro_batches_per_batch, grad_accumulator=None ) # Equivalent on the reference model if has_reference_model: @@ -420,7 +424,9 @@ def dummy_infinite_data_loader_with_non_differentiable_tensor( continue weight_grad, bias_grad = p2p.recv_tensors(num_tensors=2, from_rank=pp_rank) - torch.testing.assert_close(weight_grad, reference_non_linear.linear.pp_block.weight.grad, atol=1e-6, rtol=1e-7) + torch.testing.assert_close( + weight_grad, reference_non_linear.linear.pp_block.weight.grad, atol=1e-6, rtol=1e-7 + ) torch.testing.assert_close(bias_grad, reference_non_linear.linear.pp_block.bias.grad, atol=1e-6, rtol=1e-7) else: for (mlp_index, pp_rank) in mlp_index_pp_rank: @@ -436,7 +442,7 @@ def test_pipeline_forward_without_engine(pp: int): init_distributed(pp=pp, dp=1, tp=1)(_test_pipeline_forward_without_engine)() -def _test_pipeline_forward_without_engine(dpg: DistributedProcessGroups): +def _test_pipeline_forward_without_engine(parallel_context: ParallelContext): def activation(x: torch.Tensor, y: torch.Tensor): return {"output": F.sigmoid(x) * y, "y": y} @@ -492,9 +498,9 @@ def forward( return differentiable_tensor device = torch.device("cuda") - p2p = P2P(dpg.pp_pg, device=device) + p2p = P2P(parallel_context.pp_pg, device=device) reference_rank = 0 - current_pp_rank = dist.get_rank(dpg.pp_pg) + current_pp_rank = dist.get_rank(parallel_context.pp_pg) has_reference_model = current_pp_rank == reference_rank # spawn model @@ -503,12 +509,12 @@ def forward( reference_model = DummyModel(p2p=p2p) # Set the ranks - assert len(model.mlp) == dpg.pp_pg.size() + assert len(model.mlp) == parallel_context.pp_pg.size() with init_on_device_and_dtype(device): - for pp_rank, non_linear in zip(range(dpg.pp_pg.size()), model.mlp): + for pp_rank, non_linear in zip(range(parallel_context.pp_pg.size()), model.mlp): non_linear.linear.build_and_set_rank(pp_rank=pp_rank) non_linear.activation.build_and_set_rank(pp_rank=pp_rank) - model.loss.build_and_set_rank(pp_rank=dpg.pp_pg.size() - 1) + model.loss.build_and_set_rank(pp_rank=parallel_context.pp_pg.size() - 1) # build reference model if has_reference_model: @@ -520,7 +526,7 @@ def forward( # synchronize weights if has_reference_model: with torch.inference_mode(): - for pp_rank in range(dpg.pp_pg.size()): + for pp_rank in range(parallel_context.pp_pg.size()): non_linear = model.mlp[pp_rank] reference_non_linear = reference_model.mlp[pp_rank] if pp_rank == current_pp_rank: @@ -555,11 +561,11 @@ def dummy_infinite_data_loader_with_non_differentiable_tensor( } data_iterator = dummy_infinite_data_loader_with_non_differentiable_tensor( - pp_pg=dpg.pp_pg + pp_pg=parallel_context.pp_pg ) # First rank receives data # Have at least as many microbatches as PP size. 
- n_micro_batches_per_batch = dpg.pp_pg.size() + 5 + n_micro_batches_per_batch = parallel_context.pp_pg.size() + 5 batch = [next(data_iterator) for _ in range(n_micro_batches_per_batch)] @@ -609,7 +615,7 @@ def test_pipeline_engine_diamond(pipeline_engine: PipelineEngine): pass -def _test_pipeline_engine_diamond(dpg: DistributedProcessGroups, pipeline_engine: PipelineEngine): +def _test_pipeline_engine_diamond(parallel_context: ParallelContext, pipeline_engine: PipelineEngine): class DiamondModel(nn.Module): def __init__(self, p2p: P2P): super().__init__() @@ -703,9 +709,9 @@ def forward(self, x): return self.loss(x=out)["output"] device = torch.device("cuda") - p2p = P2P(dpg.pp_pg, device=device) + p2p = P2P(parallel_context.pp_pg, device=device) reference_rank = 0 - current_pp_rank = dist.get_rank(dpg.pp_pg) + current_pp_rank = dist.get_rank(parallel_context.pp_pg) has_reference_model = current_pp_rank == reference_rank # spawn model @@ -714,15 +720,17 @@ def forward(self, x): reference_model = DiamondModel(p2p=p2p) # Set the ranks - assert dpg.pp_pg.size() == len([model.dense_bottom, model.dense_left, model.dense_right, model.dense_top]) - assert dpg.pp_pg.size() == 4 + assert parallel_context.pp_pg.size() == len( + [model.dense_bottom, model.dense_left, model.dense_right, model.dense_top] + ) + assert parallel_context.pp_pg.size() == 4 pp_rank_to_dense_name = ["dense_bottom", "dense_left", "dense_right", "dense_top"] with init_on_device_and_dtype(device): for pp_rank, module_name in enumerate(pp_rank_to_dense_name): non_linear = model.get_submodule(module_name) non_linear.linear.build_and_set_rank(pp_rank=pp_rank) non_linear.activation.build_and_set_rank(pp_rank=pp_rank) - model.loss.build_and_set_rank(pp_rank=dpg.pp_pg.size() - 1) + model.loss.build_and_set_rank(pp_rank=parallel_context.pp_pg.size() - 1) # build reference model if has_reference_model: @@ -768,17 +776,17 @@ def dummy_infinite_data_loader_with_non_differentiable_tensor( } data_iterator = dummy_infinite_data_loader_with_non_differentiable_tensor( - pp_pg=dpg.pp_pg + pp_pg=parallel_context.pp_pg ) # First rank receives data # Have at least as many microbatches as PP size. 
- n_micro_batches_per_batch = dpg.pp_pg.size() + 5 + n_micro_batches_per_batch = parallel_context.pp_pg.size() + 5 batch = [next(data_iterator) for _ in range(n_micro_batches_per_batch)] losses = pipeline_engine.train_batch_iter( - model, pg=dpg.pp_pg, batch=batch, nb_microbatches=n_micro_batches_per_batch, grad_accumulator=None + model, pg=parallel_context.pp_pg, batch=batch, nb_microbatches=n_micro_batches_per_batch, grad_accumulator=None ) - + # Equivalent on the reference model if has_reference_model: reference_losses = [] diff --git a/tests/test_random_state.py b/tests/test_random_state.py index c35d66ee..841f120d 100644 --- a/tests/test_random_state.py +++ b/tests/test_random_state.py @@ -2,13 +2,13 @@ import torch from helpers.utils import available_gpus, init_distributed from nanotron.core import distributed as dist -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.random import ( RandomStates, branch_random_state, get_current_random_state, get_synced_random_state, ) +from nanotron.distributed import ParallelContext @pytest.mark.skipif(available_gpus() < 2, reason="Testing test_random_state_sync requires at least 2 gpus") @@ -18,10 +18,12 @@ def test_random_state_sync(tp: int, dp: int, pp: int): init_distributed(tp=tp, dp=dp, pp=pp)(_test_random_state_sync)() -def _test_random_state_sync(dpg: DistributedProcessGroups): +def _test_random_state_sync(parallel_context: ParallelContext): current_random_state = get_current_random_state() reference_rank = 0 - pg = next((pg for pg in [dpg.tp_pg, dpg.dp_pg, dpg.pp_pg] if pg.size() == 2)) + pg = next( + (pg for pg in [parallel_context.tp_pg, parallel_context.dp_pg, parallel_context.pp_pg] if pg.size() == 2) + ) # Check that they are not equal across process group if dist.get_rank(pg) == reference_rank: diff --git a/tests/test_serialize.py b/tests/test_serialize.py index 7b517bd5..c66911bd 100644 --- a/tests/test_serialize.py +++ b/tests/test_serialize.py @@ -8,6 +8,7 @@ init_distributed, is_dict_equal, ) +from nanotron.constants import CHECKPOINT_VERSION from nanotron.core import distributed as dist from nanotron.core.gradient_accumulator import FP32GradientAccumulator from nanotron.core.optim.named_optimizer import NamedOptimizer @@ -20,8 +21,8 @@ ) from nanotron.core.parallel.sharded_parameters import SplitConfig, create_sharded_parameter_from_config from nanotron.core.parallel.tied_parameters import sync_tied_weights_gradients -from nanotron.core.process_groups import DistributedProcessGroups -from nanotron.core.random import get_current_random_state, get_synced_random_state, RandomStates +from nanotron.core.random import RandomStates, get_current_random_state, get_synced_random_state +from nanotron.distributed import ParallelContext from nanotron.serialize import ( load_optimizer, load_random_states, @@ -30,7 +31,6 @@ save_random_states, save_weights, ) -from nanotron.constants import CHECKPOINT_VERSION from nanotron.serialize.metadata import TensorMetadataV2 from torch.nn.parallel import DistributedDataParallel @@ -54,15 +54,15 @@ def test_save_and_load_model(tp: int, dp: int, pp: int): init_distributed(tp=tp, dp=dp, pp=pp)(_test_save_and_load_model)(test_context=test_context) -def _test_save_and_load_model(dpg: DistributedProcessGroups, test_context: TestContext): - model = init_dummy_model(dpg=dpg) +def _test_save_and_load_model(parallel_context: ParallelContext, test_context: TestContext): + model = init_dummy_model(dpg=parallel_context) store_folder = test_context.get_auto_remove_tmp_dir() # Save 
- save_weights(model=model, dpg=dpg, root_folder=store_folder) + save_weights(model=model, dpg=parallel_context, root_folder=store_folder) # Load - new_model = init_dummy_model(dpg=dpg) + new_model = init_dummy_model(dpg=parallel_context) # Check that the newly initialised model isn't the same. match, msg = is_dict_equal(new_model.state_dict(), model.state_dict()) @@ -72,7 +72,7 @@ def _test_save_and_load_model(dpg: DistributedProcessGroups, test_context: TestC else: assert not match, "Newly initialised model should not match." - load_weights(model=new_model, dpg=dpg, root_folder=store_folder) + load_weights(model=new_model, dpg=parallel_context, root_folder=store_folder) # Assert the weights are exactly the same after loading match, msg = is_dict_equal(new_model.state_dict(), model.state_dict()) @@ -93,30 +93,32 @@ def test_save_and_load_optimizer(tp: int, dp: int, pp: int): init_distributed(tp=tp, dp=dp, pp=pp)(_test_save_and_load_optimizer)(test_context=test_context) -def _test_save_and_load_optimizer(dpg: DistributedProcessGroups, test_context: TestContext): +def _test_save_and_load_optimizer(parallel_context: ParallelContext, test_context: TestContext): store_folder = test_context.get_auto_remove_tmp_dir() - model = init_dummy_model(dpg=dpg) + model = init_dummy_model(dpg=parallel_context) optimizer = NamedOptimizer( named_params_or_groups=model.named_parameters(), optimizer_builder=lambda params: torch.optim.AdamW(params), ) # Train in order to update the optimizer step a few times - data_loader = iter(dummy_infinite_data_loader(pp_pg=dpg.pp_pg)) + data_loader = iter(dummy_infinite_data_loader(pp_pg=parallel_context.pp_pg)) nb_optim_steps = 3 pipeline_engine = AllForwardAllBackwardPipelineEngine() for _ in range(nb_optim_steps): minibatch = next(data_loader) - _ = pipeline_engine.train_batch_iter(model=model, pg=dpg.pp_pg, batch=[minibatch], nb_microbatches=1, grad_accumulator=None) + _ = pipeline_engine.train_batch_iter( + model=model, pg=parallel_context.pp_pg, batch=[minibatch], nb_microbatches=1, grad_accumulator=None + ) # Manually sync tied parameters - sync_tied_weights_gradients(module=model, dpg=dpg, grad_accumulator=None) + sync_tied_weights_gradients(module=model, dpg=parallel_context, grad_accumulator=None) # Optimizer steps optimizer.step() optimizer.zero_grad() # Save optimizer - save_optimizer(optimizer=optimizer, dpg=dpg, root_folder=store_folder) - dist.barrier(dpg.world_pg) + save_optimizer(optimizer=optimizer, dpg=parallel_context, root_folder=store_folder) + dist.barrier(parallel_context.world_pg) # Generate a new optimizer new_optimizer = NamedOptimizer( @@ -132,7 +134,7 @@ def _test_save_and_load_optimizer(dpg: DistributedProcessGroups, test_context: T else: assert not match, "Newly initialised optimizer should not match." - load_optimizer(optimizer=new_optimizer, dpg=dpg, root_folder=store_folder) + load_optimizer(optimizer=new_optimizer, dpg=parallel_context, root_folder=store_folder) # Assert the optimizer states are exactly the same after loading. 
match, msg = is_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) @@ -153,34 +155,36 @@ def test_save_zero_optimizer_and_load_optimizer(tp: int, dp: int, pp: int): init_distributed(tp=tp, dp=dp, pp=pp)(_test_save_zero_optimizer_and_load_optimizer)(test_context=test_context) -def _test_save_zero_optimizer_and_load_optimizer(dpg: DistributedProcessGroups, test_context: TestContext): +def _test_save_zero_optimizer_and_load_optimizer(parallel_context: ParallelContext, test_context: TestContext): store_folder = test_context.get_auto_remove_tmp_dir() - model = init_dummy_model(dpg=dpg) + model = init_dummy_model(dpg=parallel_context) optimizer = ZeroDistributedOptimizer( named_params_or_groups=model.named_parameters(), optimizer_builder=lambda named_param_groups: NamedOptimizer( named_params_or_groups=named_param_groups, optimizer_builder=lambda param_groups: torch.optim.AdamW(param_groups), ), - dp_pg=dpg.dp_pg, + dp_pg=parallel_context.dp_pg, ) # Train in order to update the optimizer step a few times - data_loader = iter(dummy_infinite_data_loader(pp_pg=dpg.pp_pg)) + data_loader = iter(dummy_infinite_data_loader(pp_pg=parallel_context.pp_pg)) nb_optim_steps = 3 pipeline_engine = AllForwardAllBackwardPipelineEngine() for _ in range(nb_optim_steps): minibatch = next(data_loader) - _ = pipeline_engine.train_batch_iter(model=model, pg=dpg.pp_pg, batch=[minibatch], nb_microbatches=1, grad_accumulator=None) + _ = pipeline_engine.train_batch_iter( + model=model, pg=parallel_context.pp_pg, batch=[minibatch], nb_microbatches=1, grad_accumulator=None + ) # Manually sync tied parameters - sync_tied_weights_gradients(module=model, dpg=dpg, grad_accumulator=None) + sync_tied_weights_gradients(module=model, dpg=parallel_context, grad_accumulator=None) # Optimizer steps optimizer.step() optimizer.zero_grad() # Save optimizer - save_optimizer(optimizer=optimizer, dpg=dpg, root_folder=store_folder) - dist.barrier(dpg.world_pg) + save_optimizer(optimizer=optimizer, dpg=parallel_context, root_folder=store_folder) + dist.barrier(parallel_context.world_pg) # Generate a new optimizer new_optimizer = ZeroDistributedOptimizer( @@ -189,7 +193,7 @@ def _test_save_zero_optimizer_and_load_optimizer(dpg: DistributedProcessGroups, named_params_or_groups=named_param_groups, optimizer_builder=lambda param_groups: torch.optim.AdamW(param_groups), ), - dp_pg=dpg.dp_pg, + dp_pg=parallel_context.dp_pg, ) # Check that the newly initialised optimizer isn't the same. @@ -200,7 +204,7 @@ def _test_save_zero_optimizer_and_load_optimizer(dpg: DistributedProcessGroups, else: assert not match, "Newly initialised optimizer should not match." - load_optimizer(optimizer=new_optimizer, dpg=dpg, root_folder=store_folder) + load_optimizer(optimizer=new_optimizer, dpg=parallel_context, root_folder=store_folder) # Assert the optimizer states are exactly the same after loading. 
match, msg = is_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) @@ -225,35 +229,37 @@ def test_save_zero_optimizer_and_load_data_parallel_optimizer(tp: int, dp: int, def _test_save_zero_optimizer_and_load_data_parallel_optimizer( - dpg: DistributedProcessGroups, test_context: TestContext + parallel_context: ParallelContext, test_context: TestContext ): store_folder = test_context.get_auto_remove_tmp_dir() - model = init_dummy_model(dpg=dpg) + model = init_dummy_model(dpg=parallel_context) optimizer = ZeroDistributedOptimizer( named_params_or_groups=model.named_parameters(), optimizer_builder=lambda named_param_groups: NamedOptimizer( named_params_or_groups=named_param_groups, optimizer_builder=lambda param_groups: torch.optim.AdamW(param_groups), ), - dp_pg=dpg.dp_pg, + dp_pg=parallel_context.dp_pg, ) # Train in order to update the optimizer step a few times - data_loader = iter(dummy_infinite_data_loader(pp_pg=dpg.pp_pg)) + data_loader = iter(dummy_infinite_data_loader(pp_pg=parallel_context.pp_pg)) nb_optim_steps = 3 pipeline_engine = AllForwardAllBackwardPipelineEngine() for _ in range(nb_optim_steps): minibatch = next(data_loader) - _ = pipeline_engine.train_batch_iter(model=model, pg=dpg.pp_pg, batch=[minibatch], nb_microbatches=1, grad_accumulator=None) + _ = pipeline_engine.train_batch_iter( + model=model, pg=parallel_context.pp_pg, batch=[minibatch], nb_microbatches=1, grad_accumulator=None + ) # Manually sync tied parameters - sync_tied_weights_gradients(module=model, dpg=dpg, grad_accumulator=None) + sync_tied_weights_gradients(module=model, dpg=parallel_context, grad_accumulator=None) # Optimizer steps optimizer.step() optimizer.zero_grad() # Save optimizer - save_optimizer(optimizer=optimizer, dpg=dpg, root_folder=store_folder) - dist.barrier(dpg.world_pg) + save_optimizer(optimizer=optimizer, dpg=parallel_context, root_folder=store_folder) + dist.barrier(parallel_context.world_pg) # Generate a new optimizer new_optimizer = NamedOptimizer( @@ -269,7 +275,7 @@ def _test_save_zero_optimizer_and_load_data_parallel_optimizer( else: assert not match, "Newly initialised optimizer should not match." 
- load_optimizer(optimizer=new_optimizer, dpg=dpg, root_folder=store_folder) + load_optimizer(optimizer=new_optimizer, dpg=parallel_context, root_folder=store_folder) # TODO @thomasw21: Compare zero optimizer with non zero @@ -292,28 +298,30 @@ def test_save_data_parallel_optimizer_and_load_zero_optimizer(tp: int, dp: int, def _test_save_data_parallel_optimizer_and_load_zero_optimizer( - dpg: DistributedProcessGroups, test_context: TestContext + parallel_context: ParallelContext, test_context: TestContext ): store_folder = test_context.get_auto_remove_tmp_dir() - model = init_dummy_model(dpg=dpg) + model = init_dummy_model(dpg=parallel_context) optimizer = NamedOptimizer( named_params_or_groups=model.named_parameters(), optimizer_builder=lambda params: torch.optim.AdamW(params), ) # Train in order to update the optimizer step a few times - data_loader = iter(dummy_infinite_data_loader(pp_pg=dpg.pp_pg)) + data_loader = iter(dummy_infinite_data_loader(pp_pg=parallel_context.pp_pg)) nb_optim_steps = 3 pipeline_engine = AllForwardAllBackwardPipelineEngine() for _ in range(nb_optim_steps): minibatch = next(data_loader) - _ = pipeline_engine.train_batch_iter(model=model, pg=dpg.pp_pg, batch=[minibatch], nb_microbatches=1, grad_accumulator=None) + _ = pipeline_engine.train_batch_iter( + model=model, pg=parallel_context.pp_pg, batch=[minibatch], nb_microbatches=1, grad_accumulator=None + ) optimizer.step() optimizer.zero_grad() # Save optimizer - save_optimizer(optimizer=optimizer, dpg=dpg, root_folder=store_folder) - dist.barrier(dpg.world_pg) + save_optimizer(optimizer=optimizer, dpg=parallel_context, root_folder=store_folder) + dist.barrier(parallel_context.world_pg) # Generate a new optimizer new_optimizer = ZeroDistributedOptimizer( @@ -322,7 +330,7 @@ def _test_save_data_parallel_optimizer_and_load_zero_optimizer( named_params_or_groups=named_param_groups, optimizer_builder=lambda param_groups: torch.optim.AdamW(param_groups), ), - dp_pg=dpg.dp_pg, + dp_pg=parallel_context.dp_pg, ) # Check that the newly initialised optimizer isn't the same. @@ -333,7 +341,7 @@ def _test_save_data_parallel_optimizer_and_load_zero_optimizer( else: assert not match, "Newly initialised optimizer should not match." - load_optimizer(optimizer=new_optimizer, dpg=dpg, root_folder=store_folder) + load_optimizer(optimizer=new_optimizer, dpg=parallel_context, root_folder=store_folder) # TODO @thomasw21: Compare zero optimizer with non zero @@ -354,10 +362,10 @@ def test_save_optimizer_with_additional_state_dict_keys(tp: int, dp: int, pp: in ) -def _test_save_optimizer_with_additional_state_dict_keys(dpg: DistributedProcessGroups, test_context: TestContext): +def _test_save_optimizer_with_additional_state_dict_keys(parallel_context: ParallelContext, test_context: TestContext): dtype = torch.float16 store_folder = test_context.get_auto_remove_tmp_dir() - model = init_dummy_model(dpg=dpg, dtype=dtype) + model = init_dummy_model(dpg=parallel_context, dtype=dtype) if isinstance(model, DistributedDataParallel): # Remove the annoying "module." 
prefix @@ -380,23 +388,27 @@ def _test_save_optimizer_with_additional_state_dict_keys(dpg: DistributedProcess assert len(optimizer.state_dict_additional_keys()) > 0 # Train in order to update the optimizer step a few times - data_loader = iter(dummy_infinite_data_loader(pp_pg=dpg.pp_pg, dtype=dtype)) + data_loader = iter(dummy_infinite_data_loader(pp_pg=parallel_context.pp_pg, dtype=dtype)) nb_optim_steps = 3 pipeline_engine = AllForwardAllBackwardPipelineEngine() for _ in range(nb_optim_steps): minibatch = next(data_loader) _ = pipeline_engine.train_batch_iter( - model=model, pg=dpg.pp_pg, batch=[minibatch], nb_microbatches=1, grad_accumulator=grad_accumulator + model=model, + pg=parallel_context.pp_pg, + batch=[minibatch], + nb_microbatches=1, + grad_accumulator=grad_accumulator, ) # Manually sync tied parameters - sync_tied_weights_gradients(module=normalized_model, dpg=dpg, grad_accumulator=grad_accumulator) + sync_tied_weights_gradients(module=normalized_model, dpg=parallel_context, grad_accumulator=grad_accumulator) # Optimizer steps optimizer.step() optimizer.zero_grad() # Save optimizer - save_optimizer(optimizer=optimizer, dpg=dpg, root_folder=store_folder) - dist.barrier(dpg.world_pg) + save_optimizer(optimizer=optimizer, dpg=parallel_context, root_folder=store_folder) + dist.barrier(parallel_context.world_pg) # Generate a new optimizer new_optimizer = OptimizerFromGradientAccumulator( @@ -416,7 +428,7 @@ def _test_save_optimizer_with_additional_state_dict_keys(dpg: DistributedProcess match, msg = is_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) assert not match, "Newly initialised optimizer should not match." - load_optimizer(optimizer=new_optimizer, dpg=dpg, root_folder=store_folder) + load_optimizer(optimizer=new_optimizer, dpg=parallel_context, root_folder=store_folder) # Assert the optimizer states are exactly the same after loading. 
match, msg = is_dict_equal(optimizer.state_dict()["state"], new_optimizer.state_dict()["state"]) @@ -451,8 +463,10 @@ def test_save_and_load_random_states(): init_distributed(tp=2, dp=1, pp=1)(_test_save_and_load_random_states)(test_context=test_context) -def _test_save_and_load_random_states(dpg: DistributedProcessGroups, test_context: TestContext): - pg = next((pg for pg in [dpg.tp_pg, dpg.dp_pg, dpg.pp_pg] if pg.size() == 2)) +def _test_save_and_load_random_states(parallel_context: ParallelContext, test_context: TestContext): + pg = next( + (pg for pg in [parallel_context.tp_pg, parallel_context.dp_pg, parallel_context.pp_pg] if pg.size() == 2) + ) random_states = RandomStates( { "my_synced_random_state": get_synced_random_state(random_state=get_current_random_state(), pg=pg), @@ -472,10 +486,10 @@ def _test_save_and_load_random_states(dpg: DistributedProcessGroups, test_contex assert random_states != random_statess[0] # save - save_random_states(random_states=random_states, dpg=dpg, root_folder=store_folder) + save_random_states(random_states=random_states, dpg=parallel_context, root_folder=store_folder) # load - new_random_states = load_random_states(dpg=dpg, root_folder=store_folder) + new_random_states = load_random_states(dpg=parallel_context, root_folder=store_folder) # Each rank has restored it's own random state assert random_states == new_random_states @@ -485,13 +499,13 @@ def test_serialize_deserialize_tensormetadata(): init_distributed(tp=2, dp=1, pp=1)(_test_serialize_deserialize_tensormetadata)(test_context=test_context) -def _test_serialize_deserialize_tensormetadata(dpg: DistributedProcessGroups, test_context: TestContext): +def _test_serialize_deserialize_tensormetadata(parallel_context: ParallelContext, test_context: TestContext): param = torch.nn.Parameter(torch.randn(16, 64)) split_config = SplitConfig( split_dim=0, contiguous_chunks=(8, 8), ) - param = create_sharded_parameter_from_config(parameter=param, pg=dpg.tp_pg, split_config=split_config) + param = create_sharded_parameter_from_config(parameter=param, pg=parallel_context.tp_pg, split_config=split_config) sharded_info = param.get_sharded_info() metadata = TensorMetadataV2( version=CHECKPOINT_VERSION, diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index c3197902..5ddddb82 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -13,7 +13,7 @@ TensorParallelEmbedding, TensorParallelRowLinear, ) -from nanotron.core.process_groups import DistributedProcessGroups +from nanotron.distributed import ParallelContext from torch import nn as torch_nn @@ -26,18 +26,20 @@ def test_column_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearM ) -def _test_column_linear(dpg: DistributedProcessGroups, tp_mode: TensorParallelLinearMode, async_communication: bool): +def _test_column_linear( + parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool +): if async_communication: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" in_features = 2 out_features_per_tp_rank = 3 - out_features = dpg.tp_pg.size() * out_features_per_tp_rank + out_features = parallel_context.tp_pg.size() * out_features_per_tp_rank # Sharded column_linear = TensorParallelColumnLinear( in_features=in_features, out_features=out_features, - pg=dpg.tp_pg, + pg=parallel_context.tp_pg, mode=tp_mode, device="cuda", async_communication=async_communication, @@ -51,12 +53,12 @@ def _test_column_linear(dpg: DistributedProcessGroups, tp_mode: TensorParallelLi 
dist.all_gather( tensor_list=list(reference_linear.weight.split(out_features_per_tp_rank, dim=0)), tensor=column_linear.weight, - group=dpg.tp_pg, + group=parallel_context.tp_pg, ) dist.all_gather( tensor_list=list(reference_linear.bias.split(out_features_per_tp_rank, dim=0)), tensor=column_linear.bias, - group=dpg.tp_pg, + group=parallel_context.tp_pg, ) # Generate random input @@ -66,19 +68,19 @@ def _test_column_linear(dpg: DistributedProcessGroups, tp_mode: TensorParallelLi batch_size = 5 random_input = torch.randn(batch_size, in_features, device="cuda") # synchronize random_input across tp - dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=dpg.tp_pg) + dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg) sharded_random_input = random_input elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: sharded_batch_size = 5 sharded_random_input = torch.randn(sharded_batch_size, in_features, device="cuda") - if dpg.tp_pg.size() > 1: + if parallel_context.tp_pg.size() > 1: random_input = torch.empty( - sharded_batch_size * dpg.tp_pg.size(), + sharded_batch_size * parallel_context.tp_pg.size(), *(sharded_random_input.shape[1:]), device=sharded_random_input.device, dtype=sharded_random_input.dtype, ) - dist.all_gather_into_tensor(random_input, sharded_random_input, group=dpg.tp_pg) + dist.all_gather_into_tensor(random_input, sharded_random_input, group=parallel_context.tp_pg) else: random_input = sharded_random_input else: @@ -97,23 +99,24 @@ def _test_column_linear(dpg: DistributedProcessGroups, tp_mode: TensorParallelLi sharded_output, reference_output[ :, - dist.get_rank(dpg.tp_pg) - * out_features_per_tp_rank : (dist.get_rank(dpg.tp_pg) + 1) + dist.get_rank(parallel_context.tp_pg) + * out_features_per_tp_rank : (dist.get_rank(parallel_context.tp_pg) + 1) * out_features_per_tp_rank, ], ) except BaseException as e: - print(f"Rank {dist.get_rank(dpg.tp_pg)}: FAIL.") + print(f"Rank {dist.get_rank(parallel_context.tp_pg)}: FAIL.") dist.barrier() raise e - print(f"Rank {dist.get_rank(dpg.tp_pg)}: SUCCESS.") + print(f"Rank {dist.get_rank(parallel_context.tp_pg)}: SUCCESS.") dist.barrier() # Test that we get the same gradient after backward pass sharded_output.sum().backward() reference_output.sum().backward() hidden_dim_slice = slice( - dist.get_rank(dpg.tp_pg) * out_features_per_tp_rank, (dist.get_rank(dpg.tp_pg) + 1) * out_features_per_tp_rank + dist.get_rank(parallel_context.tp_pg) * out_features_per_tp_rank, + (dist.get_rank(parallel_context.tp_pg) + 1) * out_features_per_tp_rank, ) torch.testing.assert_close( column_linear.weight.grad, @@ -130,7 +133,8 @@ def _test_column_linear(dpg: DistributedProcessGroups, tp_mode: TensorParallelLi ) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: batch_dim_slice = slice( - dist.get_rank(dpg.tp_pg) * sharded_batch_size, (dist.get_rank(dpg.tp_pg) + 1) * sharded_batch_size + dist.get_rank(parallel_context.tp_pg) * sharded_batch_size, + (dist.get_rank(parallel_context.tp_pg) + 1) * sharded_batch_size, ) torch.testing.assert_close( sharded_random_input.grad, @@ -147,7 +151,14 @@ def _test_column_linear(dpg: DistributedProcessGroups, tp_mode: TensorParallelLi pytest.param(TensorParallelLinearMode.ALL_REDUCE, False, does_not_raise()), pytest.param(TensorParallelLinearMode.REDUCE_SCATTER, False, does_not_raise()), pytest.param(TensorParallelLinearMode.REDUCE_SCATTER, True, does_not_raise()), - pytest.param(TensorParallelLinearMode.ALL_REDUCE, True, pytest.raises(AssertionError, match=r"Cf this: 
https://github.com/huggingface/nanotron/blob/bf82cded9eef1ba77864b48e65bffefad4076339/src/nanotron/core/parallel/tensor_parallelism/nn.py#L132")), + pytest.param( + TensorParallelLinearMode.ALL_REDUCE, + True, + pytest.raises( + AssertionError, + match=r"Cf this: https://github.com/huggingface/nanotron/blob/bf82cded9eef1ba77864b48e65bffefad4076339/src/nanotron/core/parallel/tensor_parallelism/nn.py#L132", + ), + ), ], ) def test_row_linear( @@ -159,19 +170,19 @@ def test_row_linear( def _test_row_linear( - dpg: DistributedProcessGroups, tp_mode: TensorParallelLinearMode, async_communication: bool, expectation: Any + parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool, expectation: Any ): if async_communication: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" out_features = 3 in_features_per_rank = 2 - in_features = dpg.tp_pg.size() * in_features_per_rank + in_features = parallel_context.tp_pg.size() * in_features_per_rank # Sharded row_linear = TensorParallelRowLinear( in_features=in_features, out_features=out_features, - pg=dpg.tp_pg, + pg=parallel_context.tp_pg, mode=tp_mode, device="cuda", async_communication=async_communication, @@ -182,38 +193,41 @@ def _test_row_linear( # Copy weights/bias from sharded to un-sharded with torch.inference_mode(): - dist.all_reduce(tensor=reference_linear.weight, op=dist.ReduceOp.SUM, group=dpg.tp_pg) + dist.all_reduce(tensor=reference_linear.weight, op=dist.ReduceOp.SUM, group=parallel_context.tp_pg) row_linear.weight.copy_( reference_linear.weight[ :, - dist.get_rank(dpg.tp_pg) - * in_features_per_rank : (dist.get_rank(dpg.tp_pg) + 1) + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) * in_features_per_rank, ] ) # broadcast bias from rank 0, and the other don't have bias - if dist.get_rank(dpg.tp_pg) == 0: + if dist.get_rank(parallel_context.tp_pg) == 0: row_linear.bias.copy_(reference_linear.bias) dist.broadcast( tensor=reference_linear.bias, - src=get_global_rank(group=dpg.tp_pg, group_rank=0), - group=dpg.tp_pg, + src=get_global_rank(group=parallel_context.tp_pg, group_rank=0), + group=parallel_context.tp_pg, ) # Generate random input if tp_mode is TensorParallelLinearMode.ALL_REDUCE: batch_size = 5 elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - batch_size = 5 * dpg.tp_pg.size() + batch_size = 5 * parallel_context.tp_pg.size() else: raise ValueError() random_input = torch.randn(batch_size, in_features, device="cuda") # synchronize random_input across tp - dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=dpg.tp_pg) + dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg) # Row linear receives as input sharded input random_sharded_input = random_input[ - :, dist.get_rank(dpg.tp_pg) * in_features_per_rank : (dist.get_rank(dpg.tp_pg) + 1) * in_features_per_rank + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, ] # Test that we get the same output after forward pass @@ -225,10 +239,12 @@ def _test_row_linear( if tp_mode is TensorParallelLinearMode.ALL_REDUCE: sharded_reference_output = reference_output elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - assert batch_size % dpg.tp_pg.size() == 0 - sharded_batch_size = batch_size // dpg.tp_pg.size() + assert batch_size % parallel_context.tp_pg.size() == 0 + sharded_batch_size = batch_size // parallel_context.tp_pg.size() sharded_reference_output = 
reference_output[ - dist.get_rank(dpg.tp_pg) * sharded_batch_size : (dist.get_rank(dpg.tp_pg) + 1) * sharded_batch_size + dist.get_rank(parallel_context.tp_pg) + * sharded_batch_size : (dist.get_rank(parallel_context.tp_pg) + 1) + * sharded_batch_size ] else: raise ValueError(f"Unsupported mode: {tp_mode}") @@ -246,12 +262,12 @@ def _test_row_linear( row_linear.weight.grad, reference_linear.weight.grad[ :, - dist.get_rank(dpg.tp_pg) - * in_features_per_rank : (dist.get_rank(dpg.tp_pg) + 1) + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) * in_features_per_rank, ], ) - if dist.get_rank(dpg.tp_pg) == 0: + if dist.get_rank(parallel_context.tp_pg) == 0: torch.testing.assert_close( row_linear.bias.grad, reference_linear.bias.grad, @@ -266,14 +282,18 @@ def test_tensor_parallel_embedding(tp: int, dp: int, pp: int, tp_mode: TensorPar init_distributed(tp=tp, dp=dp, pp=pp)(_test_tensor_parallel_embedding)(tp_mode=tp_mode) -def _test_tensor_parallel_embedding(dpg: DistributedProcessGroups, tp_mode: TensorParallelLinearMode): +def _test_tensor_parallel_embedding(parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode): num_embeddings_per_rank = 100 embedding_dim = 3 - num_embeddings = dpg.tp_pg.size() * num_embeddings_per_rank + num_embeddings = parallel_context.tp_pg.size() * num_embeddings_per_rank # Sharded sharded_embedding = TensorParallelEmbedding( - num_embeddings=num_embeddings, embedding_dim=embedding_dim, pg=dpg.tp_pg, mode=tp_mode, device="cuda" + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + pg=parallel_context.tp_pg, + mode=tp_mode, + device="cuda", ) # Un-sharded @@ -281,11 +301,11 @@ def _test_tensor_parallel_embedding(dpg: DistributedProcessGroups, tp_mode: Tens # Copy weights/bias from sharded to un-sharded with torch.inference_mode(): - dist.all_reduce(tensor=reference_embedding.weight, op=dist.ReduceOp.SUM, group=dpg.tp_pg) + dist.all_reduce(tensor=reference_embedding.weight, op=dist.ReduceOp.SUM, group=parallel_context.tp_pg) sharded_embedding.weight.copy_( reference_embedding.weight[ - dist.get_rank(dpg.tp_pg) - * num_embeddings_per_rank : (dist.get_rank(dpg.tp_pg) + 1) + dist.get_rank(parallel_context.tp_pg) + * num_embeddings_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) * num_embeddings_per_rank, :, ] @@ -296,11 +316,11 @@ def _test_tensor_parallel_embedding(dpg: DistributedProcessGroups, tp_mode: Tens if tp_mode is TensorParallelLinearMode.ALL_REDUCE: batch_size = 5 elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - batch_size = 5 * dpg.tp_pg.size() + batch_size = 5 * parallel_context.tp_pg.size() else: raise ValueError(f"Unsupported mode: {tp_mode}") random_input = torch.randint(low=0, high=num_embeddings, size=(batch_size,), device="cuda") - dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=dpg.tp_pg) + dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg) # Test that we get the same output after forward pass sharded_output = sharded_embedding(random_input) @@ -311,13 +331,17 @@ def _test_tensor_parallel_embedding(dpg: DistributedProcessGroups, tp_mode: Tens sharded_reference_output = reference_output sharded_weights = weights elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - assert batch_size % dpg.tp_pg.size() == 0 - sharded_batch_size = batch_size // dpg.tp_pg.size() + assert batch_size % parallel_context.tp_pg.size() == 0 + sharded_batch_size = batch_size // parallel_context.tp_pg.size() sharded_reference_output = 
reference_output[ - dist.get_rank(dpg.tp_pg) * sharded_batch_size : (dist.get_rank(dpg.tp_pg) + 1) * sharded_batch_size + dist.get_rank(parallel_context.tp_pg) + * sharded_batch_size : (dist.get_rank(parallel_context.tp_pg) + 1) + * sharded_batch_size ] sharded_weights = weights[ - dist.get_rank(dpg.tp_pg) * sharded_batch_size : (dist.get_rank(dpg.tp_pg) + 1) * sharded_batch_size + dist.get_rank(parallel_context.tp_pg) + * sharded_batch_size : (dist.get_rank(parallel_context.tp_pg) + 1) + * sharded_batch_size ] else: raise ValueError(f"Unsupported mode: {tp_mode}") @@ -331,8 +355,8 @@ def _test_tensor_parallel_embedding(dpg: DistributedProcessGroups, tp_mode: Tens torch.testing.assert_close( sharded_embedding.weight.grad, reference_embedding.weight.grad[ - dist.get_rank(dpg.tp_pg) - * num_embeddings_per_rank : (dist.get_rank(dpg.tp_pg) + 1) + dist.get_rank(parallel_context.tp_pg) + * num_embeddings_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) * num_embeddings_per_rank, :, ], diff --git a/tests/test_tie_weights.py b/tests/test_tie_weights.py index 455f1d3a..0b733657 100644 --- a/tests/test_tie_weights.py +++ b/tests/test_tie_weights.py @@ -9,7 +9,7 @@ sync_tied_weights_gradients, tie_parameters, ) -from nanotron.core.process_groups import DistributedProcessGroups +from nanotron.distributed import ParallelContext from torch import nn @@ -17,18 +17,21 @@ def test_tie_weight_in_same_device(): init_distributed(tp=1, dp=1, pp=1)(_test_tie_weight_in_same_device)() -def _test_tie_weight_in_same_device(dpg: DistributedProcessGroups): +def _test_tie_weight_in_same_device(parallel_context: ParallelContext): model = nn.ModuleDict({"dense0": nn.Linear(10, 10, device="cuda"), "dense1": nn.Linear(10, 10, device="cuda")}) # Tie weights/bias tie_parameters( root_module=model, ties=[("dense0.weight", (0,)), ("dense1.weight", (0,))], - dpg=dpg, + dpg=parallel_context, reduce_op=dist.ReduceOp.SUM, ) tie_parameters( - root_module=model, ties=[("dense0.bias", (0,)), ("dense1.bias", (0,))], dpg=dpg, reduce_op=dist.ReduceOp.SUM + root_module=model, + ties=[("dense0.bias", (0,)), ("dense1.bias", (0,))], + dpg=parallel_context, + reduce_op=dist.ReduceOp.SUM, ) weight0 = model.get_parameter("dense0.weight") @@ -45,8 +48,8 @@ def test_tie_weight_in_different_device(): init_distributed(tp=1, dp=1, pp=2)(_test_tie_weight_in_different_device)() -def _test_tie_weight_in_different_device(dpg: DistributedProcessGroups): - if dist.get_rank(dpg.pp_pg) == 0: +def _test_tie_weight_in_different_device(parallel_context: ParallelContext): + if dist.get_rank(parallel_context.pp_pg) == 0: model = nn.ModuleDict( { "dense0": nn.Linear(10, 10, device="cuda"), @@ -63,17 +66,20 @@ def _test_tie_weight_in_different_device(dpg: DistributedProcessGroups): tie_parameters( root_module=model, ties=[("dense0.weight", (0,)), ("dense1.weight", (1,))], - dpg=dpg, + dpg=parallel_context, reduce_op=dist.ReduceOp.SUM, ) tie_parameters( - root_module=model, ties=[("dense0.bias", (0,)), ("dense1.bias", (1,))], dpg=dpg, reduce_op=dist.ReduceOp.SUM + root_module=model, + ties=[("dense0.bias", (0,)), ("dense1.bias", (1,))], + dpg=parallel_context, + reduce_op=dist.ReduceOp.SUM, ) - group = dpg.world_ranks_to_pg[(0, 1)] + group = parallel_context.world_ranks_to_pg[(0, 1)] # Check that model weights are not in fact synchronized - if dist.get_rank(dpg.pp_pg) == 0: + if dist.get_rank(parallel_context.pp_pg) == 0: weight = model.dense0.weight bias = model.dense0.bias else: @@ -98,7 +104,7 @@ def _test_tie_weight_in_different_device(dpg: 
DistributedProcessGroups): ).items(), key=lambda x: x[0], ): - group = dpg.world_ranks_to_pg[group_ranks] + group = parallel_context.world_ranks_to_pg[group_ranks] dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group) # We check that we use the same parameter for both linear layers @@ -110,8 +116,8 @@ def test_tie_weight_across_dp_is_impossible(): init_distributed(tp=1, dp=2, pp=1)(_test_tie_weight_across_dp_is_impossible)() -def _test_tie_weight_across_dp_is_impossible(dpg: DistributedProcessGroups): - if dist.get_rank(dpg.dp_pg) == 0: +def _test_tie_weight_across_dp_is_impossible(parallel_context: ParallelContext): + if dist.get_rank(parallel_context.dp_pg) == 0: model = nn.ModuleDict( { "dense0": nn.Linear(10, 10, device="cuda"), @@ -129,14 +135,14 @@ def _test_tie_weight_across_dp_is_impossible(dpg: DistributedProcessGroups): tie_parameters( root_module=model, ties=[("dense0.weight", (0,)), ("dense1.weight", (1,))], - dpg=dpg, + dpg=parallel_context, reduce_op=dist.ReduceOp.SUM, ) with assert_fail_with(AssertionError): tie_parameters( root_module=model, ties=[("dense0.bias", (0,)), ("dense1.bias", (1,))], - dpg=dpg, + dpg=parallel_context, reduce_op=dist.ReduceOp.SUM, ) @@ -145,8 +151,8 @@ def test_tie_weight_in_different_device_have_gradients_synchronized(): init_distributed(tp=1, dp=1, pp=2)(_test_tie_weight_in_different_device_have_gradients_synchronized)() -def _test_tie_weight_in_different_device_have_gradients_synchronized(dpg: DistributedProcessGroups): - if dist.get_rank(dpg.pp_pg) == 0: +def _test_tie_weight_in_different_device_have_gradients_synchronized(parallel_context: ParallelContext): + if dist.get_rank(parallel_context.pp_pg) == 0: model = nn.ModuleDict( { "dense0": nn.Linear(10, 10, device="cuda"), @@ -163,17 +169,20 @@ def _test_tie_weight_in_different_device_have_gradients_synchronized(dpg: Distri tie_parameters( root_module=model, ties=[("dense0.weight", (0,)), ("dense1.weight", (1,))], - dpg=dpg, + dpg=parallel_context, reduce_op=dist.ReduceOp.SUM, ) tie_parameters( - root_module=model, ties=[("dense0.bias", (0,)), ("dense1.bias", (1,))], dpg=dpg, reduce_op=dist.ReduceOp.SUM + root_module=model, + ties=[("dense0.bias", (0,)), ("dense1.bias", (1,))], + dpg=parallel_context, + reduce_op=dist.ReduceOp.SUM, ) - group = dpg.world_ranks_to_pg[(0, 1)] + group = parallel_context.world_ranks_to_pg[(0, 1)] # Check that model weights are not in fact synchronized - if dist.get_rank(dpg.pp_pg) == 0: + if dist.get_rank(parallel_context.pp_pg) == 0: weight = model.dense0.weight bias = model.dense0.bias else: @@ -192,7 +201,7 @@ def _test_tie_weight_in_different_device_have_gradients_synchronized(dpg: Distri # Compute gradient input_ = torch.randn(13, 10, device="cuda") - if dist.get_rank(dpg.pp_pg) == 0: + if dist.get_rank(parallel_context.pp_pg) == 0: out = model.dense0(input_) else: out = model.dense1(input_) @@ -200,7 +209,7 @@ def _test_tie_weight_in_different_device_have_gradients_synchronized(dpg: Distri # sync gradients # TODO @thomasw21: This should be done in hooks - sync_tied_weights_gradients(model, dpg=dpg, grad_accumulator=None) + sync_tied_weights_gradients(model, dpg=parallel_context, grad_accumulator=None) # Check that we have gradient assert weight.grad is not None diff --git a/tests/test_zero.py b/tests/test_zero.py index 04abcf8c..d7f6675e 100644 --- a/tests/test_zero.py +++ b/tests/test_zero.py @@ -16,8 +16,8 @@ from nanotron.core.parallel.tensor_parallelism import nn from nanotron.core.parallel.tensor_parallelism.enum import TensorParallelLinearMode from 
nanotron.core.parallel.tied_parameters import sync_tied_weights_gradients -from nanotron.core.process_groups import DistributedProcessGroups -from nanotron.core.random import branch_random_state, get_current_random_state, get_synced_random_state, RandomStates +from nanotron.core.random import RandomStates, branch_random_state, get_current_random_state, get_synced_random_state +from nanotron.distributed import ParallelContext from torch import nn as torch_nn from torch.nn.parallel import DistributedDataParallel @@ -27,20 +27,20 @@ def test_zero_optimizer(tp: int, dp: int, pp: int): init_distributed(pp=pp, dp=dp, tp=tp)(_test_zero_optimizer)() -def _test_zero_optimizer(dpg: DistributedProcessGroups): - model = init_dummy_model(dpg=dpg) +def _test_zero_optimizer(parallel_context: ParallelContext): + model = init_dummy_model(dpg=parallel_context) optimizer = ZeroDistributedOptimizer( named_params_or_groups=model.named_parameters(), optimizer_builder=lambda named_param_groups: NamedOptimizer( named_params_or_groups=named_param_groups, optimizer_builder=lambda param_groups: torch.optim.AdamW(param_groups), ), - dp_pg=dpg.dp_pg, + dp_pg=parallel_context.dp_pg, ) index_to_name = [name for name, _ in model.named_parameters()] # reference model - reference_model = init_dummy_model(dpg=dpg) + reference_model = init_dummy_model(dpg=parallel_context) reference_optimizer = torch.optim.AdamW(reference_model.parameters()) # sync weights between reference_model and model @@ -50,7 +50,7 @@ def _test_zero_optimizer(dpg: DistributedProcessGroups): param.copy_(ref_param) # Get infinite dummy data iterator - data_loader = iter(dummy_infinite_data_loader(pp_pg=dpg.pp_pg)) + data_loader = iter(dummy_infinite_data_loader(pp_pg=parallel_context.pp_pg)) nb_optim_steps = 3 batches = [[next(data_loader)] for _ in range(nb_optim_steps)] pipeline_engine = AllForwardAllBackwardPipelineEngine() @@ -61,9 +61,11 @@ def _test_zero_optimizer(dpg: DistributedProcessGroups): old_named_params = {name: param.detach().clone() for name, param in model.named_parameters()} # Run forward/backward - losses = pipeline_engine.train_batch_iter(model=model, pg=dpg.pp_pg, batch=batch, nb_microbatches=1, grad_accumulator=None) + losses = pipeline_engine.train_batch_iter( + model=model, pg=parallel_context.pp_pg, batch=batch, nb_microbatches=1, grad_accumulator=None + ) ref_losses = pipeline_engine.train_batch_iter( - model=reference_model, pg=dpg.pp_pg, batch=batch, nb_microbatches=1, grad_accumulator=None + model=reference_model, pg=parallel_context.pp_pg, batch=batch, nb_microbatches=1, grad_accumulator=None ) # Check loss match @@ -73,25 +75,29 @@ def _test_zero_optimizer(dpg: DistributedProcessGroups): for loss, ref_loss in zip(losses, ref_losses): assert isinstance(loss["loss"], torch.Tensor) assert isinstance(ref_loss["loss"], torch.Tensor) - torch.testing.assert_close(loss["loss"], ref_loss["loss"], atol=0, rtol=0, msg=lambda msg: f"At iteration {i}, {msg}") + torch.testing.assert_close( + loss["loss"], ref_loss["loss"], atol=0, rtol=0, msg=lambda msg: f"At iteration {i}, {msg}" + ) # Manually sync tied parameters' gradients - sync_tied_weights_gradients(module=model, dpg=dpg, grad_accumulator=None) - sync_tied_weights_gradients(module=reference_model, dpg=dpg, grad_accumulator=None) + sync_tied_weights_gradients(module=model, dpg=parallel_context, grad_accumulator=None) + sync_tied_weights_gradients(module=reference_model, dpg=parallel_context, grad_accumulator=None) # We rely on DDP to synchronize gradients across DP. 
We only need to manually synchronize them if we don't use DDP. if not isinstance(model, DistributedDataParallel): - sync_gradients_across_dp(model, dp_pg=dpg.dp_pg, reduce_op=dist.ReduceOp.AVG, grad_accumulator=None) + sync_gradients_across_dp( + model, dp_pg=parallel_context.dp_pg, reduce_op=dist.ReduceOp.AVG, grad_accumulator=None + ) if not isinstance(reference_model, DistributedDataParallel): sync_gradients_across_dp( - reference_model, dp_pg=dpg.dp_pg, reduce_op=dist.ReduceOp.AVG, grad_accumulator=None + reference_model, dp_pg=parallel_context.dp_pg, reduce_op=dist.ReduceOp.AVG, grad_accumulator=None ) # Check gradients are synced across DP for name, param in model.named_parameters(): - assert_tensor_equal_over_group(param.grad, group=dpg.dp_pg) + assert_tensor_equal_over_group(param.grad, group=parallel_context.dp_pg) for ref_name, ref_param in reference_model.named_parameters(): - assert_tensor_equal_over_group(ref_param.grad, group=dpg.dp_pg) + assert_tensor_equal_over_group(ref_param.grad, group=parallel_context.dp_pg) # Check gradients are the same with reference_model for (name, param), (ref_name, ref_param) in zip(model.named_parameters(), reference_model.named_parameters()): @@ -104,7 +110,7 @@ def _test_zero_optimizer(dpg: DistributedProcessGroups): assert len(list(model.named_parameters())) == len(optimizer.param_groups[0]["params"]) with torch.no_grad(): for (name, param), sliced_param in zip(model.named_parameters(), optimizer.param_groups[0]["params"]): - offsets = optimizer.param_name_to_dp_rank_offsets[name][dist.get_rank(dpg.dp_pg)] + offsets = optimizer.param_name_to_dp_rank_offsets[name][dist.get_rank(parallel_context.dp_pg)] # Check that weights are the same expected_slice = param.view(-1)[slice(*offsets)].view_as(sliced_param) @@ -140,12 +146,12 @@ def _test_zero_optimizer(dpg: DistributedProcessGroups): # Check that params are synced across DP for name, param in model.named_parameters(): - assert_tensor_equal_over_group(param, group=dpg.dp_pg) + assert_tensor_equal_over_group(param, group=parallel_context.dp_pg) assert param.grad is None # Check that gradients are reset for ref_name, ref_param in reference_model.named_parameters(): - assert_tensor_equal_over_group(ref_param, group=dpg.dp_pg) + assert_tensor_equal_over_group(ref_param, group=parallel_context.dp_pg) assert ref_param.grad is None for param_group in optimizer.param_groups: for param in param_group["params"]: @@ -173,7 +179,7 @@ def _test_zero_optimizer(dpg: DistributedProcessGroups): ref_optim_state = ref_state[index] name = index_to_name[index] - offsets = optimizer.param_name_to_dp_rank_offsets[name][dist.get_rank(dpg.dp_pg)] + offsets = optimizer.param_name_to_dp_rank_offsets[name][dist.get_rank(parallel_context.dp_pg)] assert set(optim_state) == set(ref_optim_state) @@ -201,23 +207,27 @@ def test_zero_optimizer_with_tp( def _test_zero_optimizer_with_tp( - dpg: DistributedProcessGroups, tp_mode: TensorParallelLinearMode, async_communication: bool + parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool ): if async_communication: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" model = torch_nn.Sequential( nn.TensorParallelColumnLinear( in_features=5, - out_features=dpg.tp_pg.size(), + out_features=parallel_context.tp_pg.size(), mode=tp_mode, - pg=dpg.tp_pg, + pg=parallel_context.tp_pg, device="cuda", async_communication=async_communication, ), # We choose `sigmoid` instead of `relu` since `relu` can result in a sparse gradient, causing no update to certain 
parameters torch_nn.Sigmoid(), nn.TensorParallelRowLinear( - in_features=dpg.tp_pg.size(), out_features=3, mode=tp_mode, pg=dpg.tp_pg, device="cuda" + in_features=parallel_context.tp_pg.size(), + out_features=3, + mode=tp_mode, + pg=parallel_context.tp_pg, + device="cuda", ), ) optimizer = ZeroDistributedOptimizer( @@ -226,16 +236,16 @@ def _test_zero_optimizer_with_tp( named_params_or_groups=named_param_groups, optimizer_builder=lambda param_groups: torch.optim.AdamW(param_groups), ), - dp_pg=dpg.dp_pg, + dp_pg=parallel_context.dp_pg, ) optimizer_name_to_id = {v: k for k, v in optimizer.optimizer.id_to_name.items()} assert len(optimizer_name_to_id) == len(optimizer.id_to_name) # reference model reference_model = torch_nn.Sequential( - torch_nn.Linear(in_features=5, out_features=dpg.tp_pg.size(), device="cuda"), + torch_nn.Linear(in_features=5, out_features=parallel_context.tp_pg.size(), device="cuda"), torch_nn.Sigmoid(), - torch_nn.Linear(in_features=dpg.tp_pg.size(), out_features=3, device="cuda"), + torch_nn.Linear(in_features=parallel_context.tp_pg.size(), out_features=3, device="cuda"), ) for module in reference_model.modules(): for name, param in module.named_parameters(recurse=False): @@ -248,7 +258,7 @@ def _test_zero_optimizer_with_tp( # sync parameters with torch.no_grad(): for ref_name, ref_param in reference_model.named_parameters(): - dist.all_reduce(ref_param, op=dist.ReduceOp.AVG, group=dpg.world_pg) + dist.all_reduce(ref_param, op=dist.ReduceOp.AVG, group=parallel_context.world_pg) for (name, param), (ref_name, ref_param) in zip(model.named_parameters(), reference_model.named_parameters()): assert name == ref_name @@ -266,14 +276,16 @@ def _test_zero_optimizer_with_tp( # Get infinite dummy data iterator, it has to be synced across TP random_states = RandomStates( { - "tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=dpg.tp_pg), + "tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=parallel_context.tp_pg), } ) - batch_size = 2 * dpg.tp_pg.size() if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER else 7 + batch_size = 2 * parallel_context.tp_pg.size() if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER else 7 with branch_random_state(random_states=random_states, key="tp_synced", enabled=True): nb_optim_steps = 3 batches = [ - torch.randn(batch_size, 5, device="cuda") if dist.get_rank(dpg.pp_pg) == 0 else TensorPointer(0) + torch.randn(batch_size, 5, device="cuda") + if dist.get_rank(parallel_context.pp_pg) == 0 + else TensorPointer(0) for _ in range(nb_optim_steps) ] @@ -285,9 +297,13 @@ def _test_zero_optimizer_with_tp( # Run forward pass if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: batch_size = batch.shape[0] - assert batch_size % dpg.tp_pg.size() == 0 - step = batch_size // dpg.tp_pg.size() - loss = model(batch[dist.get_rank(dpg.tp_pg) * step : (dist.get_rank(dpg.tp_pg) + 1) * step]) + assert batch_size % parallel_context.tp_pg.size() == 0 + step = batch_size // parallel_context.tp_pg.size() + loss = model( + batch[ + dist.get_rank(parallel_context.tp_pg) * step : (dist.get_rank(parallel_context.tp_pg) + 1) * step + ] + ) else: loss = model(batch) ref_loss = reference_model(batch) @@ -303,33 +319,37 @@ def _test_zero_optimizer_with_tp( assert isinstance(ref_loss, torch.Tensor) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: batch_size = batch.shape[0] - assert batch_size % dpg.tp_pg.size() == 0 - step = batch_size // dpg.tp_pg.size() + assert batch_size % parallel_context.tp_pg.size() == 0 + 
step = batch_size // parallel_context.tp_pg.size() torch.testing.assert_close( loss, - ref_loss[dist.get_rank(dpg.tp_pg) * step : (dist.get_rank(dpg.tp_pg) + 1) * step], + ref_loss[ + dist.get_rank(parallel_context.tp_pg) * step : (dist.get_rank(parallel_context.tp_pg) + 1) * step + ], msg=lambda msg: f"At iteration {i}, {msg}", ) else: torch.testing.assert_close(loss, ref_loss, msg=lambda msg: f"At iteration {i}, {msg}") # Manually sync tied parameters - sync_tied_weights_gradients(module=model, dpg=dpg, grad_accumulator=None) - sync_tied_weights_gradients(module=reference_model, dpg=dpg, grad_accumulator=None) + sync_tied_weights_gradients(module=model, dpg=parallel_context, grad_accumulator=None) + sync_tied_weights_gradients(module=reference_model, dpg=parallel_context, grad_accumulator=None) # We rely on DDP to synchronize gradients across DP. We only need to manually synchronize them if we don't use DDP. if not isinstance(model, DistributedDataParallel): - sync_gradients_across_dp(model, dp_pg=dpg.dp_pg, reduce_op=dist.ReduceOp.AVG, grad_accumulator=None) + sync_gradients_across_dp( + model, dp_pg=parallel_context.dp_pg, reduce_op=dist.ReduceOp.AVG, grad_accumulator=None + ) if not isinstance(reference_model, DistributedDataParallel): sync_gradients_across_dp( - reference_model, dp_pg=dpg.dp_pg, reduce_op=dist.ReduceOp.AVG, grad_accumulator=None + reference_model, dp_pg=parallel_context.dp_pg, reduce_op=dist.ReduceOp.AVG, grad_accumulator=None ) # Check gradients are synced across DP for name, param in model.named_parameters(): - assert_tensor_equal_over_group(param.grad, group=dpg.dp_pg) + assert_tensor_equal_over_group(param.grad, group=parallel_context.dp_pg) for ref_name, ref_param in reference_model.named_parameters(): - assert_tensor_equal_over_group(ref_param.grad, group=dpg.dp_pg) + assert_tensor_equal_over_group(ref_param.grad, group=parallel_context.dp_pg) # Check gradients are the same with reference_model for (name, param), (ref_name, ref_param) in zip(model.named_parameters(), reference_model.named_parameters()): @@ -352,13 +372,13 @@ def _test_zero_optimizer_with_tp( optim_param_id_to_param = {id(param): param for param in optimizer.param_groups[0]["params"]} assert len(optim_param_id_to_param) == len(optimizer.param_groups[0]["params"]) for name, param in model.named_parameters(): - if dist.get_rank(dpg.dp_pg) not in optimizer.param_name_to_dp_rank_offsets[name]: + if dist.get_rank(parallel_context.dp_pg) not in optimizer.param_name_to_dp_rank_offsets[name]: assert name not in optimizer_name_to_id continue param_id = optimizer_name_to_id[name] sliced_param = optim_param_id_to_param[param_id] - offsets = optimizer.param_name_to_dp_rank_offsets[name][dist.get_rank(dpg.dp_pg)] + offsets = optimizer.param_name_to_dp_rank_offsets[name][dist.get_rank(parallel_context.dp_pg)] # Check that weights share the same storage expected_slice = param.view(-1)[slice(*offsets)].view_as(sliced_param) @@ -394,12 +414,12 @@ def _test_zero_optimizer_with_tp( # Check that params are synced across DP for name, param in model.named_parameters(): - assert_tensor_equal_over_group(param, group=dpg.dp_pg) + assert_tensor_equal_over_group(param, group=parallel_context.dp_pg) assert param.grad is None # Check that gradients are reset for ref_name, ref_param in reference_model.named_parameters(): - assert_tensor_equal_over_group(ref_param, group=dpg.dp_pg) + assert_tensor_equal_over_group(ref_param, group=parallel_context.dp_pg) assert ref_param.grad is None for param_group in optimizer.param_groups: 
for param in param_group["params"]: @@ -446,7 +466,7 @@ def _test_zero_optimizer_with_tp( ref_optim_state = ref_state[name_to_index[name]] - offsets = optimizer.param_name_to_dp_rank_offsets[name][dist.get_rank(dpg.dp_pg)] + offsets = optimizer.param_name_to_dp_rank_offsets[name][dist.get_rank(parallel_context.dp_pg)] assert set(optim_state) == set(ref_optim_state) assert isinstance(param, NanotronParameter) @@ -477,7 +497,7 @@ def test_sliced_flat_tensor(): init_distributed(1, 1, 1)(_test_sliced_flat_tensor)() -def _test_sliced_flat_tensor(dpg: DistributedProcessGroups): +def _test_sliced_flat_tensor(parallel_context: ParallelContext): a = torch.randn(2, 3, requires_grad=True) grad = torch.randn(2, 3) a.grad = grad From 37116f3558f0f7d7293d2d28ceed25945280bcce Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 10 Jan 2024 12:17:57 +0000 Subject: [PATCH 3/9] change all arguments, variables to parallel_context --- run_generate.py | 58 ++--- run_train.py | 12 +- src/nanotron/core/parallel/model.py | 11 +- src/nanotron/core/parallel/tied_parameters.py | 33 +-- src/nanotron/core/random.py | 3 +- src/nanotron/dataloader.py | 53 ++-- src/nanotron/generate/generation.py | 105 ++++---- src/nanotron/helpers.py | 48 ++-- src/nanotron/models/base_model.py | 13 +- src/nanotron/models/falcon.py | 31 +-- src/nanotron/models/fast/falcon.py | 33 +-- src/nanotron/models/fast/gpt2.py | 41 ++-- src/nanotron/models/fast/llama.py | 24 +- src/nanotron/models/fast/starcoder2.py | 36 +-- src/nanotron/models/gpt2.py | 36 +-- src/nanotron/models/llama.py | 24 +- src/nanotron/serialize/main.py | 39 ++- src/nanotron/serialize/metadata.py | 16 +- src/nanotron/serialize/optimizer.py | 27 +- src/nanotron/serialize/random.py | 11 +- src/nanotron/serialize/utils.py | 8 +- src/nanotron/serialize/weights.py | 64 ++--- src/nanotron/trainer.py | 230 +++++++++++------- tests/helpers/dummy.py | 24 +- tests/helpers/utils.py | 4 +- tests/pytest.ini | 2 +- tests/test_clip_grads.py | 8 +- ..._parameters_accumulate_gradient_in_fp32.py | 8 +- tests/test_serialize.py | 52 ++-- tests/test_tie_weights.py | 18 +- tests/test_zero.py | 12 +- 31 files changed, 589 insertions(+), 495 deletions(-) diff --git a/run_generate.py b/run_generate.py index e3ef047a..7babaa46 100644 --- a/run_generate.py +++ b/run_generate.py @@ -14,13 +14,13 @@ - Benchmark: USE_BENCH=1 USE_FAST=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=2 run_benchmark2.py --pp 2 --tp 1 --dp 1 --model_name huggyllama/llama-7b --ckpt-path /admin/home/ferdinand_mom/.cache/huggingface/hub/models--HuggingFaceBR4--llama-7b-orig/snapshots/2160b3d0134a99d365851a7e95864b21e873e1c3 """ -import os import argparse +import os from pathlib import Path import torch from nanotron import logging -from nanotron.config import GenerationArgs, ParallelismArgs, LoggingArgs, get_config_from_file +from nanotron.config import GenerationArgs, LoggingArgs, ParallelismArgs, get_config_from_file from nanotron.core import distributed as dist from nanotron.core.parallel.parameters import sanity_check from nanotron.core.parallel.pipeline_parallelism.engine import ( @@ -28,19 +28,18 @@ ) from nanotron.core.parallel.pipeline_parallelism.tensor_pointer import TensorPointer from nanotron.core.parallel.tensor_parallelism.enum import TensorParallelLinearMode -from nanotron.core.process_groups import get_process_groups from nanotron.core.random import ( RandomStates, get_current_random_state, get_synced_random_state, set_random_seed, ) +from nanotron.distributed import ParallelContext from 
nanotron.generate.generation import ( GenerationInput, TokenizerConfig, greedy_search_text, ) - from nanotron.helpers import set_logger_verbosity_format from nanotron.logging import log_rank from nanotron.serialize import ( @@ -51,6 +50,7 @@ logger = logging.get_logger(__name__) + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--model_name", type=str, default=None, help="Model name") @@ -74,32 +74,31 @@ def main(): recompute_granularity=None, tp_linear_async_communication=True, ) - + logging_config = LoggingArgs( log_level="info", log_level_replica="info", ) - + dtype = torch.bfloat16 # Set random states set_random_seed(42) # Initialise all process groups - dpg = get_process_groups( + parallel_context = ParallelContext( data_parallel_size=parallel_config.dp, pipeline_parallel_size=parallel_config.pp, tensor_parallel_size=parallel_config.tp, ) # Set log levels - if dist.get_rank(dpg.world_pg) == 0: + if dist.get_rank(parallel_context.world_pg) == 0: if logging_config.log_level is not None: - set_logger_verbosity_format(logging_config.log_level, dpg=dpg) + set_logger_verbosity_format(logging_config.log_level, parallel_context=parallel_context) else: if logging_config.log_level_replica is not None: - set_logger_verbosity_format(logging_config.log_level_replica, dpg=dpg) - + set_logger_verbosity_format(logging_config.log_level_replica, parallel_context=parallel_context) tokenizer_path = args.model_name # if config.yaml in checkpoint path we use it @@ -118,7 +117,7 @@ def main(): assert args.model_name is not None, "model_name must be provided or config.yaml must be in checkpoint path" model_name = args.model_name model_config: AutoConfig = AutoConfig.from_pretrained(model_name) - + # model_config.num_hidden_layers = 1 log_rank(f"model_config: {model_config}", logger=logger, level=logging.INFO, rank=0) @@ -131,7 +130,7 @@ def main(): # Get synchronized random states if parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE: random_states = RandomStates( - {"tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=dpg.tp_pg)} + {"tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=parallel_context.tp_pg)} ) else: # We don't need to sync across TP when using sequence parallel (REDUCE_SCATTER) @@ -140,17 +139,17 @@ def main(): model = DistributedTrainer.build_model( model_builder=lambda: CONFIG_TO_MODEL_CLASS[model_config_cls]( config=model_config, - dpg=dpg, + parallel_context=parallel_context, parallel_config=parallel_config, random_states=random_states, ), dtype=dtype, - dpg=dpg, + parallel_context=parallel_context, ) # Mark some parameters as tied # TODO @nouamane: this is only needed for training, can we just mark params as NanotronParameter instead? 
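# The pattern this patch applies throughout: build a ParallelContext once and hand it to every
# helper that previously took a DistributedProcessGroups ("dpg") object; the process-group
# attributes keep their names (tp_pg, pp_pg, dp_pg, world_pg), so call sites only rename the
# argument. A minimal sketch, assuming a torchrun launch (RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT
# already set by the launcher) and an illustrative tp=2, pp=1, dp=1 layout:
from nanotron.core import distributed as dist
from nanotron.distributed import ParallelContext

parallel_context = ParallelContext(
    tensor_parallel_size=2,
    pipeline_parallel_size=1,
    data_parallel_size=1,
)
tp_rank = dist.get_rank(parallel_context.tp_pg)  # was: dist.get_rank(dpg.tp_pg)
dp_size = parallel_context.dp_pg.size()          # was: dpg.dp_pg.size()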
- mark_tied_parameters(model=model, dpg=dpg, parallel_config=parallel_config) + mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) # Sanity check model sanity_check(root_module=model) @@ -163,7 +162,7 @@ def main(): level=logging.INFO, rank=0, ) - load_weights(model=model, dpg=dpg, root_folder=checkpoint_path) + load_weights(model=model, parallel_context=parallel_context, root_folder=checkpoint_path) model.eval() tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) @@ -184,7 +183,7 @@ def main(): "def fib(n)", # "This film was probably inspired by Godzilla", ] - + outputs = greedy_search_text( input_iter=(GenerationInput(text=text) for text in dummy_inputs), tokenizer=tokenizer, @@ -192,7 +191,7 @@ def main(): model=model.model, # TODO @thomasw21: Figure out how to pass p2p. p2p=model.model.p2p, - dpg=dpg, + parallel_context=parallel_context, max_new_tokens=args.max_new_tokens, max_micro_batch_size=2, generation_config=GenerationArgs(sampler="greedy", use_cache=False), @@ -204,7 +203,7 @@ def main(): ) dist.barrier() - + for output in outputs: input_ids = output.input_ids generated_ids = output.generation_ids @@ -212,28 +211,28 @@ def main(): assert isinstance(generated_ids, TensorPointer) continue assert isinstance(generated_ids, torch.Tensor) - + log_rank( f"input: {tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)[:1000]}", logger=logger, level=logging.INFO, rank=0, ) - + log_rank( f"generation: {tokenizer.decode(generated_ids[len(input_ids) :], clean_up_tokenization_spaces=False)}", logger=logger, level=logging.INFO, rank=0, ) - + log_rank( "--------------------------------------------------", logger=logger, level=logging.INFO, rank=0, ) - + if args.compare_with_no_cache: outputs = greedy_search_text( @@ -243,7 +242,7 @@ def main(): model=model.model, # TODO @thomasw21: Figure out how to pass p2p. 
p2p=model.model.p2p, - dpg=dpg, + parallel_context=parallel_context, max_new_tokens=args.max_new_tokens, max_micro_batch_size=2, generation_config=GenerationArgs(sampler="greedy", use_cache=True), @@ -255,7 +254,7 @@ def main(): ) dist.barrier() - + for output in outputs: input_ids = output.input_ids generated_ids = output.generation_ids @@ -263,21 +262,21 @@ def main(): assert isinstance(generated_ids, TensorPointer) continue assert isinstance(generated_ids, torch.Tensor) - + log_rank( f"input: {tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)[:1000]}", logger=logger, level=logging.INFO, rank=0, ) - + log_rank( f"generation: {tokenizer.decode(generated_ids[len(input_ids) :], clean_up_tokenization_spaces=False)}", logger=logger, level=logging.INFO, rank=0, ) - + log_rank( "--------------------------------------------------", logger=logger, @@ -285,5 +284,6 @@ def main(): rank=0, ) + if __name__ == "__main__": main() diff --git a/run_train.py b/run_train.py index b2f3c136..f291d3f0 100644 --- a/run_train.py +++ b/run_train.py @@ -56,7 +56,7 @@ def get_dataloader(trainer: DistributedTrainer, sanity_check_dataloader_interval output_pp_rank=output_pp_rank, vocab_size=trainer.model_config.vocab_size, seed=trainer.config.data.seed, - dpg=trainer.dpg, + parallel_context=trainer.parallel_context, )() elif isinstance(trainer.config.data.dataset, PretrainDatasetsArgs): log_rank("Using `datasets` library", logger=logger, level=logging.INFO, rank=0) @@ -64,7 +64,7 @@ def get_dataloader(trainer: DistributedTrainer, sanity_check_dataloader_interval tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" - with main_rank_first(trainer.dpg.world_pg): + with main_rank_first(trainer.parallel_context.world_pg): # 1st device processes dataset and cache it, then other devices load from cache # TODO @nouamanetazi: this may timeout before 1st device finishes processing dataset. Can we have a ctxmanager to modify timeout? # TODO: generalise to include for validation/test splits @@ -85,7 +85,7 @@ def get_dataloader(trainer: DistributedTrainer, sanity_check_dataloader_interval dataloader = get_train_dataloader( train_dataset=train_dataset, sequence_length=trainer.sequence_length, - dpg=trainer.dpg, + parallel_context=trainer.parallel_context, input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, micro_batch_size=trainer.micro_batch_size, @@ -97,9 +97,9 @@ def get_dataloader(trainer: DistributedTrainer, sanity_check_dataloader_interval # Check if we have enough samples for train_steps assert ( trainer.config.tokens.train_steps - trainer.start_iteration_step - ) * trainer.global_batch_size // trainer.dpg.dp_pg.size() < len(dataloader), ( - f"Dataset is too small for steps ({len(dataloader)} < {(trainer.config.tokens.train_steps - trainer.start_iteration_step) * trainer.global_batch_size // trainer.dpg.dp_pg.size()}), " - f"Try train_steps<={len(dataloader) * trainer.dpg.dp_pg.size() // trainer.global_batch_size + trainer.start_iteration_step}" + ) * trainer.global_batch_size // trainer.parallel_context.dp_pg.size() < len(dataloader), ( + f"Dataset is too small for steps ({len(dataloader)} < {(trainer.config.tokens.train_steps - trainer.start_iteration_step) * trainer.global_batch_size // trainer.parallel_context.dp_pg.size()}), " + f"Try train_steps<={len(dataloader) * trainer.parallel_context.dp_pg.size() // trainer.global_batch_size + trainer.start_iteration_step}" ) else: raise ValueError(f"Unhandled case of `self.config.data.dataset`. 
Got: {trainer.config.data.dataset}") diff --git a/src/nanotron/core/parallel/model.py b/src/nanotron/core/parallel/model.py index c4ea18fc..1f26e5bf 100644 --- a/src/nanotron/core/parallel/model.py +++ b/src/nanotron/core/parallel/model.py @@ -1,19 +1,18 @@ -from torch import nn - from nanotron.core import distributed as dist from nanotron.core.parallel.tied_parameters import get_tied_id_to_param -from nanotron.core.process_groups import DistributedProcessGroups +from nanotron.distributed import ParallelContext +from torch import nn -def initial_sync(model: nn.Module, dpg: DistributedProcessGroups): +def initial_sync(model: nn.Module, parallel_context: ParallelContext): # Synchronize across dp: basic assumption sorted_name_params = sorted(model.named_parameters(), key=lambda x: x[0]) for name, param in sorted_name_params: - dist.all_reduce(param, op=dist.ReduceOp.AVG, group=dpg.dp_pg) + dist.all_reduce(param, op=dist.ReduceOp.AVG, group=parallel_context.dp_pg) # Synchronize across tied weights: basic assumption for (_, group_ranks), param in sorted( get_tied_id_to_param(parameters=model.parameters(), root_module=model).items(), key=lambda x: x[0] ): - group = dpg.world_ranks_to_pg[group_ranks] + group = parallel_context.world_ranks_to_pg[group_ranks] dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group) diff --git a/src/nanotron/core/parallel/tied_parameters.py b/src/nanotron/core/parallel/tied_parameters.py index 479f73e5..ab7b2e57 100644 --- a/src/nanotron/core/parallel/tied_parameters.py +++ b/src/nanotron/core/parallel/tied_parameters.py @@ -1,15 +1,14 @@ from collections import OrderedDict from typing import Dict, List, Optional, Tuple -from torch import nn - from nanotron.core import distributed as dist from nanotron.core import logging from nanotron.core.gradient_accumulator import GradientAccumulator from nanotron.core.logging import log_rank from nanotron.core.parallel.parameters import NanotronParameter -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.utils import get_parameter_and_parent_module +from nanotron.distributed import ParallelContext +from torch import nn logger = logging.get_logger(__name__) @@ -30,7 +29,7 @@ def create_tied_parameter( def tie_parameters( root_module: nn.Module, ties: List[Tuple[str, Tuple[int, ...]]], - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, reduce_op: Optional[dist.ReduceOp], ): """ @@ -40,7 +39,7 @@ def tie_parameters( :param root_module: nn.Module :param ties: List[Tuple[str, Tuple[int, ...]]]: a tie is (param_target, global_ranks) - :param dpg: DistributedProcessGroups + :param parallel_context: ParallelContext :return: """ if len(ties) < 1: @@ -49,7 +48,11 @@ def tie_parameters( # TODO @thomasw21: When we support Zero3 this isn't true anymore dp_ranks = tuple( sorted( - {dpg.get_3d_ranks(world_rank=global_rank)[1] for _, global_ranks in ties for global_rank in global_ranks} + { + parallel_context.get_3d_ranks(world_rank=global_rank)[1] + for _, global_ranks in ties + for global_rank in global_ranks + } ) ) assert ( @@ -60,7 +63,7 @@ def tie_parameters( global_ranks = tuple(sorted(set().union(*(tie[1] for tie in ties)))) new_param = None - world_rank = dist.get_rank(dpg.world_pg) + world_rank = dist.get_rank(parallel_context.world_pg) for tie_target, tie_model_ranks in ties: if world_rank not in tie_model_ranks: continue @@ -77,7 +80,7 @@ def tie_parameters( setattr(parent_module, param_name, new_param) -def create_pg_for_tied_weights(root_module: nn.Module, dpg: 
DistributedProcessGroups): +def create_pg_for_tied_weights(root_module: nn.Module, parallel_context: ParallelContext): """Tied weights are tied across specific set of global ranks, we use this method to create process groups for each difference set of global ranks""" group_ranks = { param.get_tied_info().global_ranks @@ -85,15 +88,15 @@ def create_pg_for_tied_weights(root_module: nn.Module, dpg: DistributedProcessGr if isinstance(param, NanotronParameter) and param.is_tied } - world_group_ranks = [None] * dpg.world_pg.size() - dist.all_gather_object(world_group_ranks, group_ranks, group=dpg.world_pg) + world_group_ranks = [None] * parallel_context.world_pg.size() + dist.all_gather_object(world_group_ranks, group_ranks, group=parallel_context.world_pg) all_group_ranks = sorted( set().union(*world_group_ranks), ) for global_ranks in all_group_ranks: - if global_ranks not in dpg.world_ranks_to_pg: - dpg.world_ranks_to_pg[global_ranks] = dist.new_group(global_ranks) + if global_ranks not in parallel_context.world_ranks_to_pg: + parallel_context.world_ranks_to_pg[global_ranks] = dist.new_group(global_ranks) def get_tied_id_to_param( @@ -114,7 +117,7 @@ def get_tied_id_to_param( def sync_tied_weights_gradients( module: nn.Module, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, grad_accumulator: Optional[GradientAccumulator], ): tied_id_to_param = get_tied_id_to_param( @@ -138,7 +141,7 @@ def sync_tied_weights_gradients( f"Syncing tied weights {name} across ranks {group_ranks} ...", logger=logger, level=logging.DEBUG, - group=dpg.world_ranks_to_pg[group_ranks], + group=parallel_context.world_ranks_to_pg[group_ranks], rank=0, ) key = (group_ranks, tied_info.reduce_op) @@ -148,4 +151,4 @@ def sync_tied_weights_gradients( group_ranks_and_reduce_op_to_tensors_to_reduce[(group_ranks, tied_info.reduce_op)] = [tied_grad] for (group_ranks, reduce_op), tensors in group_ranks_and_reduce_op_to_tensors_to_reduce.items(): - dist.all_reduce_coalesced(tensors=tensors, op=reduce_op, group=dpg.world_ranks_to_pg[group_ranks]) + dist.all_reduce_coalesced(tensors=tensors, op=reduce_op, group=parallel_context.world_ranks_to_pg[group_ranks]) diff --git a/src/nanotron/core/random.py b/src/nanotron/core/random.py index 98c4e78f..a9b445f2 100644 --- a/src/nanotron/core/random.py +++ b/src/nanotron/core/random.py @@ -5,7 +5,6 @@ import numpy as np import torch - from nanotron.core import distributed as dist from nanotron.core.distributed import ProcessGroup @@ -119,7 +118,7 @@ def branch_random_state(random_states: RandomStates, key: str, enabled: bool): try: yield finally: - # Update state from dpg with the newest state + # Update state from parallel_context with the newest state new_random_state = get_current_random_state() random_states[key] = new_random_state diff --git a/src/nanotron/dataloader.py b/src/nanotron/dataloader.py index d92cd3ce..95197f41 100644 --- a/src/nanotron/dataloader.py +++ b/src/nanotron/dataloader.py @@ -4,19 +4,18 @@ import numpy as np import torch -from torch.utils.data import BatchSampler, DataLoader -from torch.utils.data.distributed import DistributedSampler - from nanotron import logging from nanotron.config import Config from nanotron.core import distributed as dist from nanotron.core.parallel.pipeline_parallelism.tensor_pointer import TensorPointer -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.random import set_random_seed from nanotron.core.utils import ( assert_fail_except_rank_with, assert_tensor_synced_across_pg, ) +from 
nanotron.distributed import ParallelContext +from torch.utils.data import BatchSampler, DataLoader +from torch.utils.data.distributed import DistributedSampler try: import datasets @@ -33,7 +32,9 @@ def sanity_check_dataloader( - dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], dpg: DistributedProcessGroups, config: Config + dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], + parallel_context: ParallelContext, + config: Config, ) -> Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]: for batch in dataloader: micro_batch = { @@ -51,8 +52,10 @@ def sanity_check_dataloader( # It's fine if mask is the same across DP continue - with assert_fail_except_rank_with(AssertionError, rank_exception=0, pg=dpg.dp_pg): - assert_tensor_synced_across_pg(tensor=value, pg=dpg.dp_pg, msg=lambda err: f"{key} {err}") + with assert_fail_except_rank_with(AssertionError, rank_exception=0, pg=parallel_context.dp_pg): + assert_tensor_synced_across_pg( + tensor=value, pg=parallel_context.dp_pg, msg=lambda err: f"{key} {err}" + ) # SANITY CHECK: Check input are synchronized throughout TP for key, value in sorted(micro_batch.items(), key=lambda x: x[0]): @@ -60,7 +63,7 @@ def sanity_check_dataloader( continue assert_tensor_synced_across_pg( tensor=value, - pg=dpg.tp_pg, + pg=parallel_context.tp_pg, msg=lambda err: f"{key} are not synchronized throughout TP {err}", ) @@ -177,13 +180,15 @@ def dummy_infinite_data_generator( output_pp_rank: int, vocab_size: int, seed: int, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, ): def dummy_infinite_data_generator() -> Generator[Dict[str, Union[torch.Tensor, TensorPointer]], None, None]: # Random generator generator = torch.Generator(device="cuda") # Make sure that TP are synced always - generator.manual_seed(seed * (1 + dist.get_rank(dpg.dp_pg)) * (1 + dist.get_rank(dpg.pp_pg))) + generator.manual_seed( + seed * (1 + dist.get_rank(parallel_context.dp_pg)) * (1 + dist.get_rank(parallel_context.pp_pg)) + ) while True: yield { @@ -195,7 +200,7 @@ def dummy_infinite_data_generator() -> Generator[Dict[str, Union[torch.Tensor, T device="cuda", generator=generator, ) - if dist.get_rank(dpg.pp_pg) == input_pp_rank + if dist.get_rank(parallel_context.pp_pg) == input_pp_rank else TensorPointer(group_rank=input_pp_rank), "input_mask": torch.ones( micro_batch_size, @@ -203,7 +208,7 @@ def dummy_infinite_data_generator() -> Generator[Dict[str, Union[torch.Tensor, T dtype=torch.bool, device="cuda", ) - if dist.get_rank(dpg.pp_pg) == input_pp_rank + if dist.get_rank(parallel_context.pp_pg) == input_pp_rank else TensorPointer(group_rank=input_pp_rank), "label_ids": torch.randint( 0, @@ -213,7 +218,7 @@ def dummy_infinite_data_generator() -> Generator[Dict[str, Union[torch.Tensor, T device="cuda", generator=generator, ) - if dist.get_rank(dpg.pp_pg) == output_pp_rank + if dist.get_rank(parallel_context.pp_pg) == output_pp_rank else TensorPointer(group_rank=output_pp_rank), "label_mask": torch.ones( micro_batch_size, @@ -221,7 +226,7 @@ def dummy_infinite_data_generator() -> Generator[Dict[str, Union[torch.Tensor, T dtype=torch.bool, device="cuda", ) - if dist.get_rank(dpg.pp_pg) == output_pp_rank + if dist.get_rank(parallel_context.pp_pg) == output_pp_rank else TensorPointer(group_rank=output_pp_rank), } @@ -232,12 +237,12 @@ def dummy_infinite_data_generator() -> Generator[Dict[str, Union[torch.Tensor, T class SkipBatchSampler(BatchSampler): """ A `torch.utils.data.BatchSampler` that skips the first `n` batches of another 
`torch.utils.data.BatchSampler`. - Note that in case of DDP, we skip batches on each rank, so a total of `skip_batches * dpg.dp_pg.size()` batches + Note that in case of DDP, we skip batches on each rank, so a total of `skip_batches * parallel_context.dp_pg.size()` batches """ def __init__(self, batch_sampler: BatchSampler, skip_batches: int, dp_size: int): self.batch_sampler = batch_sampler - # In case of DDP, we skip batches on each rank, so a total of `skip_batches * dpg.dp_pg.size()` batches + # In case of DDP, we skip batches on each rank, so a total of `skip_batches * parallel_context.dp_pg.size()` batches self.skip_batches = skip_batches // dp_size def __iter__(self): @@ -323,11 +328,11 @@ class DataCollatorForCLM: sequence_length: int input_pp_rank: int output_pp_rank: int - dpg: DistributedProcessGroups + parallel_context: ParallelContext def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: # Process the case when "input_ids" doesn't exist - current_pp_rank = dist.get_rank(self.dpg.pp_pg) + current_pp_rank = dist.get_rank(self.parallel_context.pp_pg) if current_pp_rank not in [ self.input_pp_rank, self.output_pp_rank, @@ -424,7 +429,7 @@ def _get_train_sampler( def get_train_dataloader( train_dataset: Dataset, sequence_length: int, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, input_pp_rank: int, output_pp_rank: int, micro_batch_size: int, @@ -439,7 +444,7 @@ def get_train_dataloader( raise ValueError(f"training requires a datasets.Dataset, but got {type(train_dataset)}") # Only some rank require to run the dataloader. - if dist.get_rank(dpg.pp_pg) not in [ + if dist.get_rank(parallel_context.pp_pg) not in [ input_pp_rank, output_pp_rank, ]: @@ -461,15 +466,15 @@ def get_train_dataloader( sequence_length=sequence_length, input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, - dpg=dpg, + parallel_context=parallel_context, ) # TODO @nouamanetazi: Remove unused columns: https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L852 # TODO @nouamanetazi: Support torch.utils.data.IterableDataset: https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L855-L872 train_sampler = _get_train_sampler( - dp_size=dpg.dp_pg.size(), - dp_rank=dist.get_rank(dpg.dp_pg), + dp_size=parallel_context.dp_pg.size(), + dp_rank=dist.get_rank(parallel_context.dp_pg), train_dataset=train_dataset, seed=seed_worker, use_loop_to_round_batch_size=use_loop_to_round_batch_size, @@ -486,7 +491,7 @@ def get_train_dataloader( drop_last=dataloader_drop_last, # we also drop_last in `clm_process()` num_workers=dataloader_num_workers, pin_memory=dataloader_pin_memory, - worker_init_fn=get_dataloader_worker_init(dp_rank=dist.get_rank(dpg.dp_pg)), + worker_init_fn=get_dataloader_worker_init(dp_rank=dist.get_rank(parallel_context.dp_pg)), # TODO @thomasw21: I'm not sure but this doesn't seem to work at all. 
# pin_memory_device="cuda", ) diff --git a/src/nanotron/generate/generation.py b/src/nanotron/generate/generation.py index ee0f6662..4493ea5c 100644 --- a/src/nanotron/generate/generation.py +++ b/src/nanotron/generate/generation.py @@ -4,9 +4,8 @@ from typing import Generator, Iterable, List, Optional, Tuple, Union import torch -from transformers import LlamaTokenizer - -from nanotron.config import GenerationArgs, BenchArgs +from nanotron import logging +from nanotron.config import BenchArgs, GenerationArgs from nanotron.core import distributed as dist from nanotron.core.distributed import ProcessGroup, get_global_rank from nanotron.core.parallel.pipeline_parallelism.block import get_min_max_rank @@ -14,16 +13,17 @@ from nanotron.core.parallel.pipeline_parallelism.p2p import P2P, TensorMetaData, view_as_contiguous from nanotron.core.parallel.pipeline_parallelism.state import PipelineEvalBatchState from nanotron.core.parallel.pipeline_parallelism.tensor_pointer import TensorPointer -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.utils import get_untyped_storage +from nanotron.distributed import ParallelContext from nanotron.generate.sampler import BasicSampler, GreedySampler, SamplerType, TopKSampler, TopPSampler +from nanotron.helpers import log_throughput from nanotron.models.generate_store import Store, attach_store from nanotron.models.llama import LlamaModel -from nanotron import logging -from nanotron.helpers import log_throughput +from transformers import LlamaTokenizer logger = logging.get_logger(__name__) + @dataclasses.dataclass class GenerationInput: text: str @@ -73,7 +73,7 @@ def micro_batcher( tokenizer: LlamaTokenizer, max_micro_batch_size: int, tokenizer_config: TokenizerConfig, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, input_rank: int, ) -> Generator[GenerationInputs, None, None]: """ @@ -91,11 +91,11 @@ def micro_batcher( # Empty micro batches don't matter return - if micro_batch_id % dpg.dp_pg.size() != dist.get_rank(dpg.dp_pg): + if micro_batch_id % parallel_context.dp_pg.size() != dist.get_rank(parallel_context.dp_pg): # Each dp is responsible for its own micro batches continue - if dist.get_rank(dpg.pp_pg) == input_rank: + if dist.get_rank(parallel_context.pp_pg) == input_rank: encodings = tokenizer( [elt.text for elt in micro_batch], return_tensors="pt", @@ -119,7 +119,7 @@ def micro_splitter( input_ids: torch.Tensor, input_mask: torch.Tensor, max_micro_batch_size: int, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, input_rank: int, ) -> Generator[GenerationInputs, None, None]: """ @@ -134,11 +134,11 @@ def micro_splitter( # Empty micro batches don't matter return - # if micro_batch_id % dpg.dp_pg.size() != dist.get_rank(dpg.dp_pg): + # if micro_batch_id % parallel_context.dp_pg.size() != dist.get_rank(parallel_context.dp_pg): # # Each dp is responsible for its own micro batches # continue - if dist.get_rank(dpg.pp_pg) == input_rank: + if dist.get_rank(parallel_context.pp_pg) == input_rank: micro_batch_mask = micro_batch_mask.to(dtype=torch.bool, device="cuda") micro_batch_mask.to("cuda") yield GenerationInputs(input_ids=micro_batch_ids.clone(), input_masks=micro_batch_mask.clone()) @@ -147,13 +147,14 @@ def micro_splitter( input_ids=TensorPointer(group_rank=input_rank), input_masks=TensorPointer(group_rank=input_rank) ) + @torch.inference_mode() def greedy_search_text( input_iter: Iterable[GenerationInput], tokenizer: LlamaTokenizer, model: LlamaModel, p2p: P2P, - dpg: 
DistributedProcessGroups, + parallel_context: ParallelContext, generation_config: GenerationArgs, tokenizer_config: Optional[TokenizerConfig], max_micro_batch_size: int, @@ -176,8 +177,8 @@ def greedy_search_text( sampler_type = SamplerType.GREEDY # Compute flag - is_decoder_input_rank = dist.get_rank(dpg.pp_pg) == decoder_input_rank - is_decoder_logit_rank = dist.get_rank(dpg.pp_pg) == decoder_logit_rank + is_decoder_input_rank = dist.get_rank(parallel_context.pp_pg) == decoder_input_rank + is_decoder_logit_rank = dist.get_rank(parallel_context.pp_pg) == decoder_logit_rank max_nb_microbatches = decoder_logit_rank - decoder_input_rank + 1 # TODO @thomasw21: Fix this as we shouldn't get P2P like that @@ -194,7 +195,7 @@ def greedy_search_text( max_micro_batch_size=max_micro_batch_size, tokenizer_config=tokenizer_config, input_rank=decoder_input_rank, - dpg=dpg, + parallel_context=parallel_context, ), chunk_size=max_nb_microbatches, ): @@ -222,13 +223,13 @@ def greedy_search_text( if is_bench: start_time, elapsed_time_first_iteration = time.perf_counter(), 0 - + for generation_iter in range(max_new_tokens): - + if is_bench and generation_iter == 0: torch.cuda.synchronize() elapsed_time_first_iteration = start_time - time.perf_counter() - + all_new_decoder_input_ids_and_mask_same_rank: List[ Tuple[Union[torch.LongTensor, TensorPointer], Union[torch.BoolTensor, TensorPointer]] ] = [] @@ -291,13 +292,13 @@ def greedy_search_text( # run a logit chooser. if sampler_type == SamplerType.GREEDY: - sampler = GreedySampler(pg=dpg.tp_pg) + sampler = GreedySampler(pg=parallel_context.tp_pg) elif sampler_type == SamplerType.TOP_K: - sampler = TopKSampler(pg=dpg.tp_pg) + sampler = TopKSampler(pg=parallel_context.tp_pg) elif sampler_type == SamplerType.TOP_P: - sampler = TopPSampler(pg=dpg.tp_pg) + sampler = TopPSampler(pg=parallel_context.tp_pg) elif sampler_type == SamplerType.BASIC: - sampler = BasicSampler(pg=dpg.tp_pg) + sampler = BasicSampler(pg=parallel_context.tp_pg) else: raise NotImplementedError(f"Sampler type {sampler_type} is not implemented") @@ -375,14 +376,14 @@ def generator(): new_decoder_states, all_new_decoder_input_ids_and_mask ) ) - - if is_bench: + + if is_bench: # Compute throughput (tok/s/gpu). Note that the first generation is done with full seq_len, so we don't count it. 
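# Worked example of the throughput accounting below, with illustrative numbers rather than a
# real measurement: 3 microbatches in flight, dp=2, 128 new tokens per sequence, and 4.0 s of
# generation time once the first (full-sequence-length) iteration is excluded:
example_global_batch_size = 3 * 2                               # len(batches) * parallel_context.dp_pg.size()
example_tokens_per_sec = example_global_batch_size * 128 / 4.0  # 192 generated tokens/s for the whole job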
torch.cuda.synchronize() total_time_sec = time.perf_counter() - start_time - elapsed_time_first_iteration # We generate 1 token per iteration per batch (batch=microbatch) # Number of tokens generated every iteration: gbs/iteration_time - global_batch_size = len(batches) * dpg.dp_pg.size() + global_batch_size = len(batches) * parallel_context.dp_pg.size() tokens_per_sec = global_batch_size * max_new_tokens / total_time_sec model_tflops, hardware_tflops = model.get_flops_per_sec( @@ -396,18 +397,20 @@ def generator(): sequence_length=max_new_tokens, micro_batch_size=max_micro_batch_size, batch_accumulation_per_replica=1, - benchmark_csv_path="benchmark.csv" + benchmark_csv_path="benchmark.csv", + ) + + model_size = sum( + [p.numel() * p.data.element_size() for p in chain(model.parameters(), model.buffers())] ) - - model_size = sum([p.numel() * p.data.element_size() for p in chain(model.parameters(), model.buffers())]) log_throughput( bench_config, - dpg, + parallel_context, model_tflops, hardware_tflops, tokens_per_sec, - bandwidth = model_size * tokens_per_sec / 1e9 + bandwidth=model_size * tokens_per_sec / 1e9, ) # Flush communication @@ -435,17 +438,19 @@ def generator(): # Broadcast all data batch_generated_ids, batch_generated_mask = broadcast_tensors( - [batch_generated_ids, batch_generated_mask], group_src=decoder_input_rank, group=dpg.pp_pg + [batch_generated_ids, batch_generated_mask], + group_src=decoder_input_rank, + group=parallel_context.pp_pg, ) batch.input_ids, batch.input_masks = broadcast_tensors( - [batch.input_ids, batch.input_masks], group_src=decoder_input_rank, group=dpg.pp_pg + [batch.input_ids, batch.input_masks], group_src=decoder_input_rank, group=parallel_context.pp_pg ) # Flush the store to release memory state.store.flush() assert len(state.store) == 0 - if dist.get_rank(dpg.pp_pg) == decoder_input_rank: + if dist.get_rank(parallel_context.pp_pg) == decoder_input_rank: assert ( batch_generated_ids.shape[0] == batch.input_ids.shape[0] ), f"Batch size needs to match {batch_generated_ids.shape[0]} != {batch.input_ids.shape[0]}" @@ -458,7 +463,7 @@ def generator(): for i, (generated_ids, generated_mask) in enumerate(zip(batch_generated_ids, batch_generated_mask)): # TODO @thomasw21: We could actually have all ranks return the output, since it's been already broadcasted - if dist.get_rank(dpg.pp_pg) == decoder_input_rank: + if dist.get_rank(parallel_context.pp_pg) == decoder_input_rank: input_ids = batch.input_ids[i] input_mask = batch.input_masks[i] yield GenerationOutput( @@ -478,7 +483,7 @@ def greedy_search_tokenized( input_mask: torch.Tensor, model: LlamaModel, p2p: P2P, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, generation_config: GenerationArgs, max_micro_batch_size: int, max_new_tokens: int, @@ -503,8 +508,8 @@ def greedy_search_tokenized( decoder_input_rank, decoder_logit_rank = get_min_max_rank(module=model) # Compute flag - is_decoder_input_rank = dist.get_rank(dpg.pp_pg) == decoder_input_rank - is_decoder_logit_rank = dist.get_rank(dpg.pp_pg) == decoder_logit_rank + is_decoder_input_rank = dist.get_rank(parallel_context.pp_pg) == decoder_input_rank + is_decoder_logit_rank = dist.get_rank(parallel_context.pp_pg) == decoder_logit_rank max_nb_microbatches = decoder_logit_rank - decoder_input_rank + 1 # TODO @thomasw21: Fix this as we shouldn't get P2P like that @@ -519,7 +524,7 @@ def greedy_search_tokenized( input_ids, input_mask, max_micro_batch_size=max_micro_batch_size, - dpg=dpg, + parallel_context=parallel_context, 
input_rank=decoder_input_rank, ), chunk_size=max_nb_microbatches, @@ -597,17 +602,21 @@ def greedy_search_tokenized( # run a logit chooser. if sampler_type == SamplerType.GREEDY: - sampler = GreedySampler(pg=dpg.tp_pg) + sampler = GreedySampler(pg=parallel_context.tp_pg) elif sampler_type == SamplerType.TOP_K: sampler = TopKSampler( - pg=dpg.tp_pg, k=generation_config.top_k, temperature=generation_config.temperature + pg=parallel_context.tp_pg, + k=generation_config.top_k, + temperature=generation_config.temperature, ) elif sampler_type == SamplerType.TOP_P: sampler = TopPSampler( - pg=dpg.tp_pg, p=generation_config.top_p, temperature=generation_config.temperature + pg=parallel_context.tp_pg, + p=generation_config.top_p, + temperature=generation_config.temperature, ) elif sampler_type == SamplerType.BASIC: - sampler = BasicSampler(pg=dpg.tp_pg) + sampler = BasicSampler(pg=parallel_context.tp_pg) else: raise NotImplementedError(f"Sampler type {sampler_type} is not implemented") @@ -711,17 +720,19 @@ def generator(): # Broadcast all data batch_generated_ids, batch_generated_mask = broadcast_tensors( - [batch_generated_ids, batch_generated_mask], group_src=decoder_input_rank, group=dpg.pp_pg + [batch_generated_ids, batch_generated_mask], + group_src=decoder_input_rank, + group=parallel_context.pp_pg, ) batch.input_ids, batch.input_masks = broadcast_tensors( - [batch.input_ids, batch.input_masks], group_src=decoder_input_rank, group=dpg.pp_pg + [batch.input_ids, batch.input_masks], group_src=decoder_input_rank, group=parallel_context.pp_pg ) # Flush the store to release memory state.store.flush() assert len(state.store) == 0 - if dist.get_rank(dpg.pp_pg) == decoder_input_rank: + if dist.get_rank(parallel_context.pp_pg) == decoder_input_rank: assert ( batch_generated_ids.shape[0] == batch.input_ids.shape[0] ), f"Batch size needs to match {batch_generated_ids.shape[0]} != {batch.input_ids.shape[0]}" @@ -734,7 +745,7 @@ def generator(): for i, (generated_ids, generated_mask) in enumerate(zip(batch_generated_ids, batch_generated_mask)): # TODO @thomasw21: We could actually have all ranks return the output, since it's been already broadcasted - if dist.get_rank(dpg.pp_pg) == decoder_input_rank: + if dist.get_rank(parallel_context.pp_pg) == decoder_input_rank: input_ids = batch.input_ids[i] input_mask = batch.input_masks[i] yield GenerationOutput( diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 95d9a3c4..608b7273 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -36,12 +36,12 @@ from nanotron.core.parallel.tensor_parallelism.nn import ( TensorParallelLinearMode, ) -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.random import ( RandomStates, get_current_random_state, get_synced_random_state, ) +from nanotron.distributed import ParallelContext from nanotron.logging import LogItem, log_rank from torch import nn from torch.nn.parallel import DistributedDataParallel @@ -59,11 +59,11 @@ def get_args(): return parser.parse_args() -def set_logger_verbosity_format(logging_level: str, dpg: DistributedProcessGroups): +def set_logger_verbosity_format(logging_level: str, parallel_context: ParallelContext): node_name = os.environ.get("SLURMD_NODENAME") formatter = lg.Formatter( - fmt=f"%(asctime)s [%(levelname)s|DP={dist.get_rank(dpg.dp_pg)}|PP={dist.get_rank(dpg.pp_pg)}|" - f"TP={dist.get_rank(dpg.tp_pg)}{'|' + node_name if node_name else ''}]: %(message)s", + fmt=f"%(asctime)s 
[%(levelname)s|DP={dist.get_rank(parallel_context.dp_pg)}|PP={dist.get_rank(parallel_context.pp_pg)}|" + f"TP={dist.get_rank(parallel_context.tp_pg)}{'|' + node_name if node_name else ''}]: %(message)s", datefmt="%m/%d/%Y %H:%M:%S", ) # TODO @thomasw21: `logging.log_levels` returns valid lg log levels @@ -169,7 +169,7 @@ def lr_lambda(current_step: int): def init_optimizer_and_grad_accumulator( - model: nn.Module, optimizer_args: OptimizerArgs, dpg: DistributedProcessGroups + model: nn.Module, optimizer_args: OptimizerArgs, parallel_context: ParallelContext ) -> Tuple[BaseOptimizer, GradientAccumulator]: # Normalize DDP normalized_model = model.module if isinstance(model, DistributedDataParallel) else model @@ -234,7 +234,7 @@ def grad_optimizer_builder(named_param_groups): named_params_or_groups=named_parameters, # TODO @thomasw21: We need a better API for gradient accumulation/zero etc ... optimizer_builder=optimizer_builder, - dp_pg=dpg.dp_pg, + dp_pg=parallel_context.dp_pg, ) # SANITY CHECK: assert that optimizer's named_params point to model's params (check only the first one) @@ -259,7 +259,7 @@ def grad_optimizer_builder(named_param_groups): assert isinstance(grad_accumulator, FP32GradientAccumulator) grad_accumulator.assign_param_offsets( - dp_rank=dist.get_rank(dpg.dp_pg), + dp_rank=dist.get_rank(parallel_context.dp_pg), param_name_to_offsets=param_name_to_dp_rank_offsets, ) @@ -268,7 +268,7 @@ def grad_optimizer_builder(named_param_groups): assert isinstance(grad_accumulator, FP32GradientAccumulator) model.register_comm_hook( state=FP32GradBucketManager( - dp_pg=dpg.dp_pg, + dp_pg=parallel_context.dp_pg, accumulator=grad_accumulator, param_id_to_name={ id(param): param.get_tied_info().get_full_name_from_module_id_to_prefix( @@ -376,29 +376,29 @@ def op(lst, d=4, r=1): def test_all_pair_to_pair( - dpg: DistributedProcessGroups, throughput_size: int, throughput_iters: int, only_node_to_node: bool = True + parallel_context: ParallelContext, throughput_size: int, throughput_iters: int, only_node_to_node: bool = True ): """Test all pair-to-pair GPUs throughput Args: - dpg: DistributedProcessGroups + parallel_context: ParallelContext throughput_size: size of the tensor to send throughput_iters: number of warm-up iterations before testing the throughput only_node_to_node: if True, only test node-to-node throughput """ - comparisons = get_all_comps(dpg.world_pg.size()) - wr = dist.get_rank(dpg.world_pg) + comparisons = get_all_comps(parallel_context.world_pg.size()) + wr = dist.get_rank(parallel_context.world_pg) log_rank( f"[TEST] Testing throughput between {comparisons}", logger=logger, level=logging.WARNING, - group=dpg.world_pg, + group=parallel_context.world_pg, rank=0, ) for j, comp in enumerate(comparisons): - dist.barrier(group=dpg.world_pg) + dist.barrier(group=parallel_context.world_pg) for i, (a, b) in enumerate(comp): - dist.barrier(group=dpg.world_pg) + dist.barrier(group=parallel_context.world_pg) if wr not in [a, b]: continue if only_node_to_node and (a % 8 != 0 or b % 8 != 0): @@ -409,9 +409,9 @@ def test_all_pair_to_pair( pre = time.perf_counter() torch.cuda.synchronize() if wr == a: - dist.send(test_tensor, b, group=dpg.world_pg, tag=i + k) + dist.send(test_tensor, b, group=parallel_context.world_pg, tag=i + k) elif wr == b: - dist.recv(test_tensor, a, group=dpg.world_pg, tag=i + k) + dist.recv(test_tensor, a, group=parallel_context.world_pg, tag=i + k) torch.cuda.synchronize() duration = time.perf_counter() - pre del test_tensor @@ -422,21 +422,21 @@ def 
test_all_pair_to_pair( f"[TEST] {j, i, wr} Results throughput from {a} to {b}: {tput/1e9:.4f} Gbps", logger=logger, level=logging.WARNING, - group=dpg.world_pg, + group=parallel_context.world_pg, rank=None, ) log_rank( "[TEST] All comparisons done", logger=logger, level=logging.WARNING, - group=dpg.world_pg, + group=parallel_context.world_pg, rank=0, ) def log_throughput( config: Config, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, model_tflops=0, hardware_tflops=0, tokens_per_sec=0, @@ -444,20 +444,20 @@ def log_throughput( ): micro_batch_size = config.micro_batch_size n_micro_batches_per_batch = config.batch_accumulation_per_replica - global_batch_size = micro_batch_size * n_micro_batches_per_batch * dpg.dp_pg.size() + global_batch_size = micro_batch_size * n_micro_batches_per_batch * parallel_context.dp_pg.size() sequence_length = config.sequence_length slurm_job_id = os.environ.get("SLURM_JOB_ID", "N/A") csv_filename = config.benchmark_csv_path table_log = [ LogItem("model_name", config.model_name, "s"), - LogItem("nodes", math.ceil(dpg.world_pg.size() / 8), "d"), + LogItem("nodes", math.ceil(parallel_context.world_pg.size() / 8), "d"), LogItem("seq_len", (sequence_length), "d"), LogItem("mbs", micro_batch_size, "d"), LogItem("batch_accum", n_micro_batches_per_batch, "d"), LogItem("gbs", global_batch_size, "d"), LogItem("mTFLOPs", model_tflops, ".2f"), LogItem("hTFLOPs", hardware_tflops, ".2f"), - LogItem("tok/s/gpu", tokens_per_sec / dpg.world_pg.size(), ".2f"), + LogItem("tok/s/gpu", tokens_per_sec / parallel_context.world_pg.size(), ".2f"), LogItem("Bandwidth (GB/s)", bandwidth, ".2f"), LogItem("Mem Alloc (GB)", torch.cuda.max_memory_allocated() / 1024**3, ".2f"), LogItem("Mem Res (GB)", torch.cuda.max_memory_reserved() / 1024**3, ".2f"), @@ -484,7 +484,7 @@ def log_throughput( import csv - if dist.get_rank(dpg.world_pg) == 0: + if dist.get_rank(parallel_context.world_pg) == 0: if not os.path.exists(csv_filename): with open(csv_filename, mode="w") as fo: writer = csv.writer(fo) diff --git a/src/nanotron/models/base_model.py b/src/nanotron/models/base_model.py index 2c281802..0e6c5ae3 100644 --- a/src/nanotron/models/base_model.py +++ b/src/nanotron/models/base_model.py @@ -1,14 +1,13 @@ from abc import ABCMeta, abstractmethod from typing import Optional -from torch import nn -from transformers import AutoConfig - from nanotron.core import logging from nanotron.core.distributed import ProcessGroup from nanotron.core.logging import log_rank from nanotron.core.parallel.pipeline_parallelism.block import PipelineBlock -from nanotron.core.process_groups import DistributedProcessGroups +from nanotron.distributed import ParallelContext +from torch import nn +from transformers import AutoConfig logger = logging.get_logger(__name__) @@ -20,7 +19,7 @@ class NanotronModel(nn.Module, metaclass=ABCMeta): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.dpg: DistributedProcessGroups + self.parallel_context: ParallelContext self.config: AutoConfig # Attributes defined when building the model @@ -52,13 +51,13 @@ def after_optim_step_sanity_checks(self) -> None: pass def log_modules(self, level: int = logging.DEBUG, group: Optional[ProcessGroup] = None, rank: int = 0): - assert hasattr(self, "dpg"), "`NanotronModel` needs to have a `dpg` attribute" + assert hasattr(self, "parallel_context"), "`NanotronModel` needs to have a `parallel_context` attribute" for name, module in self.named_modules(): if not isinstance(module, PipelineBlock): 
continue log_rank( - f"module_name: {name} | PP: {module.rank}/{self.dpg.pp_pg.size()}", + f"module_name: {name} | PP: {module.rank}/{self.parallel_context.pp_pg.size()}", logger=logger, level=level, group=group, diff --git a/src/nanotron/models/falcon.py b/src/nanotron/models/falcon.py index f3e01293..6254f17a 100644 --- a/src/nanotron/models/falcon.py +++ b/src/nanotron/models/falcon.py @@ -21,9 +21,6 @@ from typing import Dict, Optional, Union import torch -from torch import nn -from torch.nn import functional as F - from nanotron.config import FalconConfig, ParallelismArgs, RecomputeGranularity from nanotron.core import distributed as dist from nanotron.core import logging @@ -37,10 +34,12 @@ TensorParallelLinearMode, TensorParallelRowLinear, ) -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.random import RandomStates from nanotron.core.utils import checkpoint_method +from nanotron.distributed import ParallelContext from nanotron.models import AttachableStore, NanotronModel +from torch import nn +from torch.nn import functional as F logger = logging.get_logger(__name__) @@ -548,16 +547,16 @@ class FalconModel(nn.Module): def __init__( self, config: FalconConfig, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, parallel_config: Optional["ParallelismArgs"], ): super().__init__() # Declare all the nodes - self.p2p = P2P(dpg.pp_pg, device=torch.device("cuda")) + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) self.config = config self.parallel_config = parallel_config - self.dpg = dpg + self.parallel_context = parallel_context self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE tp_linear_async_communication = ( parallel_config.tp_linear_async_communication if parallel_config is not None else False @@ -567,7 +566,7 @@ def __init__( p2p=self.p2p, module_builder=Embedding, module_kwargs={ - "tp_pg": dpg.tp_pg, + "tp_pg": parallel_context.tp_pg, "config": config, "parallel_config": parallel_config, }, @@ -583,7 +582,7 @@ def __init__( module_kwargs={ "config": config, "parallel_config": parallel_config, - "tp_pg": dpg.tp_pg, + "tp_pg": parallel_context.tp_pg, "layer_idx": layer_idx, }, module_input_keys={"hidden_states", "sequence_mask"}, @@ -608,7 +607,7 @@ def __init__( module_kwargs={ "in_features": config.hidden_size, "out_features": config.vocab_size, - "pg": dpg.tp_pg, + "pg": parallel_context.tp_pg, "bias": False, # TODO @thomasw21: refactor so that we store that default in a single place. 
"mode": self.tp_mode, @@ -673,7 +672,7 @@ def get_block_compute_costs(self): def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): """Get flops per second for a given model""" - world_size = self.dpg.world_pg.size() + world_size = self.parallel_context.world_pg.size() model_flops, hardware_flops = get_flops( num_layers=self.config.num_hidden_layers, hidden_size=self.config.hidden_size, @@ -725,17 +724,19 @@ class FalconForTraining(NanotronModel): def __init__( self, config: FalconConfig, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, parallel_config: Optional["ParallelismArgs"], random_states: Optional[RandomStates] = None, ): super().__init__() - self.transformer = FalconModel(config=config, dpg=dpg, parallel_config=parallel_config) + self.transformer = FalconModel( + config=config, parallel_context=parallel_context, parallel_config=parallel_config + ) self.loss = PipelineBlock( p2p=self.transformer.p2p, module_builder=Loss, - module_kwargs={"tp_pg": dpg.tp_pg}, + module_kwargs={"tp_pg": parallel_context.tp_pg}, module_input_keys={ "sharded_logits", "label_ids", @@ -743,7 +744,7 @@ def __init__( }, module_output_keys={"loss"}, ) - self.dpg = dpg + self.parallel_context = parallel_context self.config = config self.parallel_config = parallel_config diff --git a/src/nanotron/models/fast/falcon.py b/src/nanotron/models/fast/falcon.py index 56e9e6dc..ab25457c 100644 --- a/src/nanotron/models/fast/falcon.py +++ b/src/nanotron/models/fast/falcon.py @@ -21,10 +21,6 @@ import torch from flash_attn.flash_attn_interface import flash_attn_varlen_func -from torch import nn -from torch.nn import functional as F -from transformers import FalconConfig - from nanotron.config import ParallelismArgs, RecomputeGranularity from nanotron.core import distributed as dist from nanotron.core import logging @@ -38,10 +34,13 @@ TensorParallelLinearMode, TensorParallelRowLinear, ) -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.random import RandomStates from nanotron.core.utils import checkpoint_method +from nanotron.distributed import ParallelContext from nanotron.models import AttachableStore, NanotronModel +from torch import nn +from torch.nn import functional as F +from transformers import FalconConfig logger = logging.get_logger(__name__) @@ -490,16 +489,16 @@ class FalconModel(nn.Module): def __init__( self, config: FalconConfig, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], ): super().__init__() # Declare all the nodes - self.p2p = P2P(dpg.pp_pg, device=torch.device("cuda")) + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) self.config = config self.parallel_config = parallel_config - self.dpg = dpg + self.parallel_context = parallel_context self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE tp_linear_async_communication = ( parallel_config.tp_linear_async_communication if parallel_config is not None else False @@ -509,7 +508,7 @@ def __init__( p2p=self.p2p, module_builder=Embedding, module_kwargs={ - "tp_pg": dpg.tp_pg, + "tp_pg": parallel_context.tp_pg, "config": config, "parallel_config": parallel_config, }, @@ -525,7 +524,7 @@ def __init__( module_kwargs={ "config": config, "parallel_config": parallel_config, - "tp_pg": dpg.tp_pg, + "tp_pg": parallel_context.tp_pg, "layer_idx": layer_idx, }, module_input_keys={"hidden_states", "sequence_mask"}, @@ -550,7 +549,7 @@ 
def __init__( module_kwargs={ "in_features": config.hidden_size, "out_features": config.vocab_size, - "pg": dpg.tp_pg, + "pg": parallel_context.tp_pg, "bias": False, # TODO @thomasw21: refactor so that we store that default in a single place. "mode": self.tp_mode, @@ -615,7 +614,7 @@ def get_block_compute_costs(self): def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): """Get flops per second for a given model""" - world_size = self.dpg.world_pg.size() + world_size = self.parallel_context.world_pg.size() model_flops, hardware_flops = get_flops( num_layers=self.config.num_hidden_layers, hidden_size=self.config.hidden_size, @@ -667,17 +666,19 @@ class FalconForTraining(NanotronModel): def __init__( self, config: FalconConfig, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: Optional[RandomStates] = None, ): super().__init__() - self.transformer = FalconModel(config=config, dpg=dpg, parallel_config=parallel_config) + self.transformer = FalconModel( + config=config, parallel_context=parallel_context, parallel_config=parallel_config + ) self.loss = PipelineBlock( p2p=self.transformer.p2p, module_builder=Loss, - module_kwargs={"tp_pg": dpg.tp_pg}, + module_kwargs={"tp_pg": parallel_context.tp_pg}, module_input_keys={ "sharded_logits", "label_ids", @@ -685,7 +686,7 @@ def __init__( }, module_output_keys={"loss"}, ) - self.dpg = dpg + self.parallel_context = parallel_context self.config = config self.parallel_config = parallel_config diff --git a/src/nanotron/models/fast/gpt2.py b/src/nanotron/models/fast/gpt2.py index fc987d7a..94486ab7 100644 --- a/src/nanotron/models/fast/gpt2.py +++ b/src/nanotron/models/fast/gpt2.py @@ -19,15 +19,7 @@ from typing import Dict, Optional, Tuple, Union import torch -from torch.nn import LayerNorm -from nanotron.fused.layer_norm import TritonLayerNorm from flash_attn.flash_attn_interface import flash_attn_varlen_func -from torch import nn -from torch.nn import functional as F -from torch.nn import init -from transformers import GPTBigCodeConfig -from transformers.activations import ACT2FN - from nanotron.config import ParallelismArgs, RecomputeGranularity from nanotron.core import distributed as dist from nanotron.core.distributed import get_global_rank @@ -44,10 +36,16 @@ TensorParallelRowLinear, ) from nanotron.core.parallel.tied_parameters import create_tied_parameter -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.random import RandomStates, branch_random_state from nanotron.core.utils import checkpoint_method +from nanotron.distributed import ParallelContext +from nanotron.fused.layer_norm import TritonLayerNorm from nanotron.models import AttachableStore, NanotronModel +from torch import nn +from torch.nn import LayerNorm, init +from torch.nn import functional as F +from transformers import GPTBigCodeConfig +from transformers.activations import ACT2FN class MLP(nn.Module): @@ -788,14 +786,14 @@ class GPTModel(nn.Module): def __init__( self, config: GPTBigCodeConfig, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: RandomStates, ): super().__init__() # Declare all the nodes - self.p2p = P2P(dpg.pp_pg, device=torch.device("cuda")) + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) self.random_states = random_states self.tp_mode = parallel_config.tp_mode if parallel_config is not None else 
TensorParallelLinearMode.ALL_REDUCE @@ -803,7 +801,7 @@ def __init__( p2p=self.p2p, module_builder=Embedding, module_kwargs={ - "tp_pg": dpg.tp_pg, + "tp_pg": parallel_context.tp_pg, "config": config, "parallel_config": parallel_config, }, @@ -827,7 +825,7 @@ def __init__( module_kwargs={ "config": config, "parallel_config": parallel_config, - "tp_pg": dpg.tp_pg, + "tp_pg": parallel_context.tp_pg, "random_states": random_states, "layer_idx": layer_idx, }, @@ -853,7 +851,7 @@ def __init__( module_kwargs={ "in_features": config.hidden_size, "out_features": config.vocab_size, - "pg": dpg.tp_pg, + "pg": parallel_context.tp_pg, "bias": False, # TODO @thomasw21: refactor so that we store that default in a single place. "mode": self.tp_mode, @@ -933,16 +931,21 @@ class GPTForTraining(NanotronModel): def __init__( self, config: GPTBigCodeConfig, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: RandomStates, ): super().__init__() - self.model = GPTModel(config=config, dpg=dpg, parallel_config=parallel_config, random_states=random_states) + self.model = GPTModel( + config=config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=random_states, + ) self.loss = PipelineBlock( p2p=self.model.p2p, module_builder=Loss, - module_kwargs={"tp_pg": dpg.tp_pg}, + module_kwargs={"tp_pg": parallel_context.tp_pg}, module_input_keys={ "sharded_logits", "label_ids", @@ -952,7 +955,7 @@ def __init__( ) self.config: GPTBigCodeConfig = config self.parallel_config = parallel_config - self.dpg = dpg + self.parallel_context = parallel_context def forward( self, @@ -1168,7 +1171,7 @@ def get_block_compute_costs(self): def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): """Get flops per second for a given model""" - world_size = self.dpg.world_pg.size() + world_size = self.parallel_context.world_pg.size() model_flops, hardware_flops = get_flops( num_layers=self.config.num_hidden_layers, hidden_size=self.config.hidden_size, diff --git a/src/nanotron/models/fast/llama.py b/src/nanotron/models/fast/llama.py index 618a5aeb..37b8f3fd 100644 --- a/src/nanotron/models/fast/llama.py +++ b/src/nanotron/models/fast/llama.py @@ -40,9 +40,9 @@ TensorParallelLinearMode, TensorParallelRowLinear, ) -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.random import RandomStates from nanotron.core.utils import checkpoint_method +from nanotron.distributed import ParallelContext from nanotron.fused.layer_norm import TritonRMSNorm from nanotron.models import AttachableStore, NanotronModel from torch import nn @@ -612,16 +612,16 @@ class LlamaModel(nn.Module): def __init__( self, config: LlamaConfig, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], ): super().__init__() # Declare all the nodes - self.p2p = P2P(dpg.pp_pg, device=torch.device("cuda")) + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) self.config = config self.parallel_config = parallel_config - self.dpg = dpg + self.parallel_context = parallel_context self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE tp_linear_async_communication = ( parallel_config.tp_linear_async_communication if parallel_config is not None else False @@ -631,7 +631,7 @@ def __init__( p2p=self.p2p, module_builder=Embedding, module_kwargs={ - "tp_pg": dpg.tp_pg, + "tp_pg": 
parallel_context.tp_pg, "config": config, "parallel_config": parallel_config, }, @@ -647,7 +647,7 @@ def __init__( module_kwargs={ "config": config, "parallel_config": parallel_config, - "tp_pg": dpg.tp_pg, + "tp_pg": parallel_context.tp_pg, "layer_idx": layer_idx, }, module_input_keys={"hidden_states", "sequence_mask"}, @@ -672,7 +672,7 @@ def __init__( module_kwargs={ "in_features": config.hidden_size, "out_features": config.vocab_size, - "pg": dpg.tp_pg, + "pg": parallel_context.tp_pg, "bias": False, # TODO @thomasw21: refactor so that we store that default in a single place. "mode": self.tp_mode, @@ -737,7 +737,7 @@ def get_block_compute_costs(self): def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): """Get flops per second for a given model""" - world_size = self.dpg.world_pg.size() + world_size = self.parallel_context.world_pg.size() try: num_key_values_heads = self.config.num_key_value_heads except AttributeError: @@ -793,16 +793,16 @@ class LlamaForTraining(NanotronModel): def __init__( self, config: LlamaConfig, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: Optional[RandomStates] = None, ): super().__init__() - self.model = LlamaModel(config=config, dpg=dpg, parallel_config=parallel_config) + self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config) self.loss = PipelineBlock( p2p=self.model.p2p, module_builder=Loss, - module_kwargs={"tp_pg": dpg.tp_pg}, + module_kwargs={"tp_pg": parallel_context.tp_pg}, module_input_keys={ "sharded_logits", "label_ids", @@ -810,7 +810,7 @@ def __init__( }, module_output_keys={"loss"}, ) - self.dpg = dpg + self.parallel_context = parallel_context self.config = config self.parallel_config = parallel_config diff --git a/src/nanotron/models/fast/starcoder2.py b/src/nanotron/models/fast/starcoder2.py index b7eb2d95..13c9adda 100644 --- a/src/nanotron/models/fast/starcoder2.py +++ b/src/nanotron/models/fast/starcoder2.py @@ -30,11 +30,6 @@ flash_attn_varlen_func, flash_attn_with_kvcache, ) -from torch import nn -from torch.nn import LayerNorm, init -from torch.nn import functional as F -from transformers.activations import ACT2FN - from nanotron.config import ParallelismArgs, RecomputeGranularity, Starcoder2Config from nanotron.core import distributed as dist from nanotron.core.distributed import get_global_rank @@ -51,11 +46,15 @@ TensorParallelRowLinear, ) from nanotron.core.parallel.tied_parameters import create_tied_parameter -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.random import RandomStates, branch_random_state from nanotron.core.utils import checkpoint_method +from nanotron.distributed import ParallelContext from nanotron.fused.layer_norm import TritonLayerNorm from nanotron.models import AttachableStore, NanotronModel +from torch import nn +from torch.nn import LayerNorm, init +from torch.nn import functional as F +from transformers.activations import ACT2FN _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_varlen_func).parameters) @@ -1257,14 +1256,14 @@ class GPTModel(nn.Module): def __init__( self, config: Starcoder2Config, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: RandomStates, ): super().__init__() # Declare all the nodes - self.p2p = P2P(dpg.pp_pg, device=torch.device("cuda")) + self.p2p = 
P2P(parallel_context.pp_pg, device=torch.device("cuda")) self.random_states = random_states self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE @@ -1272,7 +1271,7 @@ def __init__( p2p=self.p2p, module_builder=Embedding, module_kwargs={ - "tp_pg": dpg.tp_pg, + "tp_pg": parallel_context.tp_pg, "config": config, "parallel_config": parallel_config, }, @@ -1296,7 +1295,7 @@ def __init__( module_kwargs={ "config": config, "parallel_config": parallel_config, - "tp_pg": dpg.tp_pg, + "tp_pg": parallel_context.tp_pg, "random_states": random_states, "layer_idx": layer_idx, }, @@ -1322,7 +1321,7 @@ def __init__( module_kwargs={ "in_features": config.hidden_size, "out_features": config.vocab_size, - "pg": dpg.tp_pg, + "pg": parallel_context.tp_pg, "bias": False, # TODO @thomasw21: refactor so that we store that default in a single place. "mode": self.tp_mode, @@ -1404,16 +1403,21 @@ class Starcoder2ForTraining(NanotronModel): def __init__( self, config: Starcoder2Config, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: RandomStates, ): super().__init__() - self.model = GPTModel(config=config, dpg=dpg, parallel_config=parallel_config, random_states=random_states) + self.model = GPTModel( + config=config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=random_states, + ) self.loss = PipelineBlock( p2p=self.model.p2p, module_builder=Loss, - module_kwargs={"tp_pg": dpg.tp_pg}, + module_kwargs={"tp_pg": parallel_context.tp_pg}, module_input_keys={ "sharded_logits", "label_ids", @@ -1423,7 +1427,7 @@ def __init__( ) self.config: Starcoder2Config = config self.parallel_config = parallel_config - self.dpg = dpg + self.parallel_context = parallel_context def forward( self, @@ -1639,7 +1643,7 @@ def get_block_compute_costs(self): def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): """Get flops per second for a given model""" - world_size = self.dpg.world_pg.size() + world_size = self.parallel_context.world_pg.size() model_flops, hardware_flops = get_flops( num_layers=self.config.num_hidden_layers, hidden_size=self.config.hidden_size, diff --git a/src/nanotron/models/gpt2.py b/src/nanotron/models/gpt2.py index 4b636820..e26f007f 100644 --- a/src/nanotron/models/gpt2.py +++ b/src/nanotron/models/gpt2.py @@ -19,11 +19,6 @@ from typing import Dict, Optional, Tuple, Union import torch -from torch import nn -from torch.nn import LayerNorm -from transformers import GPTBigCodeConfig -from transformers.activations import ACT2FN - from nanotron.config import ParallelismArgs, RecomputeGranularity from nanotron.core import distributed as dist from nanotron.core.distributed import get_global_rank @@ -44,10 +39,14 @@ TensorParallelRowLinear, ) from nanotron.core.parallel.tied_parameters import create_tied_parameter -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.random import RandomStates, branch_random_state from nanotron.core.utils import checkpoint_method +from nanotron.distributed import ParallelContext from nanotron.models import AttachableStore, NanotronModel +from torch import nn +from torch.nn import LayerNorm +from transformers import GPTBigCodeConfig +from transformers.activations import ACT2FN class MLP(nn.Module): @@ -493,14 +492,14 @@ class GPTModel(nn.Module): def __init__( self, config: GPTBigCodeConfig, - dpg: DistributedProcessGroups, + parallel_context: 
ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: RandomStates, ): super().__init__() # Declare all the nodes - self.p2p = P2P(dpg.pp_pg, device=torch.device("cuda")) + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) self.random_states = random_states self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE @@ -508,7 +507,7 @@ def __init__( p2p=self.p2p, module_builder=Embedding, module_kwargs={ - "tp_pg": dpg.tp_pg, + "tp_pg": parallel_context.tp_pg, "config": config, "parallel_config": parallel_config, }, @@ -532,7 +531,7 @@ def __init__( module_kwargs={ "config": config, "parallel_config": parallel_config, - "tp_pg": dpg.tp_pg, + "tp_pg": parallel_context.tp_pg, "random_states": random_states, "layer_idx": layer_idx, }, @@ -558,7 +557,7 @@ def __init__( module_kwargs={ "in_features": config.hidden_size, "out_features": config.vocab_size, - "pg": dpg.tp_pg, + "pg": parallel_context.tp_pg, "bias": False, # TODO @thomasw21: refactor so that we store that default in a single place. "mode": self.tp_mode, @@ -632,16 +631,21 @@ class GPTForTraining(NanotronModel): def __init__( self, config: GPTBigCodeConfig, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: RandomStates, ): super().__init__() - self.model = GPTModel(config=config, dpg=dpg, parallel_config=parallel_config, random_states=random_states) + self.model = GPTModel( + config=config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=random_states, + ) self.loss = PipelineBlock( p2p=self.model.p2p, module_builder=Loss, - module_kwargs={"tp_pg": dpg.tp_pg}, + module_kwargs={"tp_pg": parallel_context.tp_pg}, module_input_keys={ "sharded_logits", "label_ids", @@ -651,7 +655,7 @@ def __init__( ) self.config = config self.parallel_config = parallel_config - self.dpg = dpg + self.parallel_context = parallel_context def forward( self, @@ -861,7 +865,7 @@ def get_block_compute_costs(self): def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): """Get flops per second for a given model""" - world_size = self.dpg.world_pg.size() + world_size = self.parallel_context.world_pg.size() model_flops, hardware_flops = get_flops( num_layers=self.config.num_hidden_layers, hidden_size=self.config.hidden_size, diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 9d865267..5ce50feb 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -32,9 +32,9 @@ TensorParallelLinearMode, TensorParallelRowLinear, ) -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.random import RandomStates from nanotron.core.utils import checkpoint_method +from nanotron.distributed import ParallelContext from nanotron.models import AttachableStore, NanotronModel from torch import nn from transformers.activations import ACT2FN @@ -571,16 +571,16 @@ class LlamaModel(nn.Module): def __init__( self, config: LlamaConfig, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, parallel_config: Optional["ParallelismArgs"], ): super().__init__() # Declare all the nodes - self.p2p = P2P(dpg.pp_pg, device=torch.device("cuda")) + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) self.config = config self.parallel_config = parallel_config - self.dpg = dpg + self.parallel_context = parallel_context self.tp_mode = parallel_config.tp_mode 
if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE tp_linear_async_communication = ( parallel_config.tp_linear_async_communication if parallel_config is not None else False @@ -590,7 +590,7 @@ def __init__( p2p=self.p2p, module_builder=Embedding, module_kwargs={ - "tp_pg": dpg.tp_pg, + "tp_pg": parallel_context.tp_pg, "config": config, "parallel_config": parallel_config, }, @@ -606,7 +606,7 @@ def __init__( module_kwargs={ "config": config, "parallel_config": parallel_config, - "tp_pg": dpg.tp_pg, + "tp_pg": parallel_context.tp_pg, "layer_idx": layer_idx, }, module_input_keys={"hidden_states", "sequence_mask"}, @@ -631,7 +631,7 @@ def __init__( module_kwargs={ "in_features": config.hidden_size, "out_features": config.vocab_size, - "pg": dpg.tp_pg, + "pg": parallel_context.tp_pg, "bias": False, # TODO @thomasw21: refactor so that we store that default in a single place. "mode": self.tp_mode, @@ -696,7 +696,7 @@ def get_block_compute_costs(self): def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): """Get flops per second for a given model""" - world_size = self.dpg.world_pg.size() + world_size = self.parallel_context.world_pg.size() try: num_key_values_heads = self.config.num_key_value_heads except AttributeError: @@ -752,16 +752,16 @@ class LlamaForTraining(NanotronModel): def __init__( self, config: LlamaConfig, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, parallel_config: Optional["ParallelismArgs"], random_states: Optional[RandomStates] = None, ): super().__init__() - self.model = LlamaModel(config=config, dpg=dpg, parallel_config=parallel_config) + self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config) self.loss = PipelineBlock( p2p=self.model.p2p, module_builder=Loss, - module_kwargs={"tp_pg": dpg.tp_pg}, + module_kwargs={"tp_pg": parallel_context.tp_pg}, module_input_keys={ "sharded_logits", "label_ids", @@ -769,7 +769,7 @@ def __init__( }, module_output_keys={"loss"}, ) - self.dpg = dpg + self.parallel_context = parallel_context self.config = config self.parallel_config = parallel_config diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index 7caad0e8..11f68a38 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -2,21 +2,20 @@ from typing import Optional import torch -from torch import nn -from torch.nn.parallel import DistributedDataParallel - from nanotron import logging from nanotron.config import Config from nanotron.core import distributed as dist from nanotron.core import optim as optim from nanotron.core.distributed import get_global_rank from nanotron.core.parallel.parameters import NanotronParameter -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.utils import assert_tensor_synced_across_pg +from nanotron.distributed import ParallelContext from nanotron.logging import log_rank from nanotron.serialize.metadata import CheckpointMetadata, load_meta, save_meta from nanotron.serialize.optimizer import load_lr_scheduler, load_optimizer, save_lr_scheduler, save_optimizer from nanotron.serialize.weights import load_weights, save_weights +from torch import nn +from torch.nn.parallel import DistributedDataParallel """ We're going to use safetensors. 
The reason is that loading segments is going to be much easier @@ -42,7 +41,7 @@ def save( model: nn.Module, optimizer: optim.BaseOptimizer, lr_scheduler: torch.optim.lr_scheduler.LRScheduler, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, root_folder: Path, should_save_config: bool = True, should_save_model: bool = True, @@ -62,13 +61,13 @@ def save( raise e try: if should_save_model: - save_weights(model=model, dpg=dpg, root_folder=root_folder) + save_weights(model=model, parallel_context=parallel_context, root_folder=root_folder) except Exception as e: print(f"Error while saving weights checkpoint: {e}") raise e try: if should_save_optimizer: - save_optimizer(optimizer=optimizer, dpg=dpg, root_folder=root_folder) + save_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder) except Exception as e: print(f"Error while saving optimizer checkpoint: {e}") raise e @@ -76,21 +75,21 @@ def save( if should_save_lr_scheduler: save_lr_scheduler( lr_scheduler=lr_scheduler, - dpg=dpg, + parallel_context=parallel_context, root_folder=root_folder, ) except Exception as e: print(f"Error while saving lr_scheduler checkpoint: {e}") raise e - save_meta(root_folder=root_folder, dpg=dpg, checkpoint_metadata=checkpoint_metadata) + save_meta(root_folder=root_folder, parallel_context=parallel_context, checkpoint_metadata=checkpoint_metadata) # TODO @thomas21: sanity check, not sure whether that needs to happen at testing or now (depends how much it costs) ### - # SANITY CHECK: Check that the model params are synchronized across `dpg.dp_pg` + # SANITY CHECK: Check that the model params are synchronized across `parallel_context.dp_pg` for name, param_or_buffer in sorted(model.state_dict().items(), key=lambda x: x[0]): assert_tensor_synced_across_pg( - tensor=param_or_buffer, pg=dpg.dp_pg, msg=lambda err: f"{name} are not synced across DP {err}" + tensor=param_or_buffer, pg=parallel_context.dp_pg, msg=lambda err: f"{name} are not synced across DP {err}" ) # SANITY CHECK: Check that the tied parameters are synchronized @@ -106,13 +105,13 @@ def save( for tied_param in sorted_tied_parameters: tied_info = tied_param.get_tied_info() group_ranks = tied_info.global_ranks - group = dpg.world_ranks_to_pg[group_ranks] + group = parallel_context.world_ranks_to_pg[group_ranks] assert_tensor_synced_across_pg( tensor=tied_param, pg=group, msg=lambda err: f"Tied {tied_info.name} are not synced {err}" ) if not optimizer.inherit_from(optim.ZeroDistributedOptimizer): - # SANITY CHECK: Check that the optimizer state are synchronized across `dpg.dp_pg` + # SANITY CHECK: Check that the optimizer state are synchronized across `parallel_context.dp_pg` for id_, optim_state in sorted(optimizer.state_dict()["state"].items(), key=lambda x: x[0]): for name, tensor in optim_state.items(): if name == "step": @@ -120,7 +119,7 @@ def save( tensor = tensor.to("cuda") assert_tensor_synced_across_pg( - tensor=tensor, pg=dpg.dp_pg, msg=lambda err: f"{name} are not synced across DP {err}" + tensor=tensor, pg=parallel_context.dp_pg, msg=lambda err: f"{name} are not synced across DP {err}" ) # SANITY CHECK: tied parameters have their optimizer states synchronized @@ -146,7 +145,7 @@ def save( continue tied_info = param.get_tied_info() group_ranks = tied_info.global_ranks - group = dpg.world_ranks_to_pg[group_ranks] + group = parallel_context.world_ranks_to_pg[group_ranks] reference_rank = 0 current_rank = dist.get_rank(group) @@ -172,14 +171,14 @@ def save( ) ### - dist.barrier(dpg.world_pg) 
+ dist.barrier(parallel_context.world_pg) def load( model: nn.Module, optimizer: optim.BaseOptimizer, lr_scheduler, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, root_folder: Path, ) -> CheckpointMetadata: """ @@ -189,8 +188,8 @@ def load( :param filepath: Path :return: """ - checkpoint_metadata = load_meta(dpg=dpg, root_folder=root_folder) - load_weights(model=model, dpg=dpg, root_folder=root_folder) + checkpoint_metadata = load_meta(parallel_context=parallel_context, root_folder=root_folder) + load_weights(model=model, parallel_context=parallel_context, root_folder=root_folder) # SANITY CHECK: assert that optimizer's named_params still point to model's params (check only the first one) if isinstance(optimizer, optim.ZeroDistributedOptimizer): @@ -204,7 +203,7 @@ def load( param = next(p for n, p in model.named_parameters() if n == optim_model_param_name) assert param.data_ptr() == optim_model_param.data_ptr() - load_optimizer(optimizer=optimizer, dpg=dpg, root_folder=root_folder) + load_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder) load_lr_scheduler( lr_scheduler=lr_scheduler, root_folder=root_folder, diff --git a/src/nanotron/serialize/metadata.py b/src/nanotron/serialize/metadata.py index 92ff456c..2fd2d86e 100644 --- a/src/nanotron/serialize/metadata.py +++ b/src/nanotron/serialize/metadata.py @@ -6,12 +6,11 @@ import dacite import torch from dacite import from_dict -from packaging.version import Version - from nanotron.constants import CHECKPOINT_VERSION from nanotron.core import distributed as dist from nanotron.core.parallel.parameters import SlicesPair -from nanotron.core.process_groups import DistributedProcessGroups +from nanotron.distributed import ParallelContext +from packaging.version import Version @dataclasses.dataclass @@ -113,13 +112,16 @@ def to_list(list_: Union[List, Tuple], type_hooks: Dict[Type, Callable[[Any], An return list_.__class__((process_type(elt, type_hooks=type_hooks) for elt in list_)) -def save_meta(dpg: DistributedProcessGroups, root_folder: Path, checkpoint_metadata: dict): - if dist.get_rank(dpg.world_pg) != 0: +def save_meta(parallel_context: ParallelContext, root_folder: Path, checkpoint_metadata: dict): + if dist.get_rank(parallel_context.world_pg) != 0: return root_folder.mkdir(exist_ok=True, parents=True) checkpoint_metadata = CheckpointMetadata( - version=CHECKPOINT_VERSION, tp=dpg.tp_pg.size(), dp=dpg.dp_pg.size(), metas=checkpoint_metadata + version=CHECKPOINT_VERSION, + tp=parallel_context.tp_pg.size(), + dp=parallel_context.dp_pg.size(), + metas=checkpoint_metadata, ) # There are some types that require manual casting in order to work correctly. 
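The serialize hunks above rename the `dpg` argument of `save()` and `load()` in src/nanotron/serialize/main.py (and of `save_meta`/`load_meta`) to `parallel_context`. Below is a minimal call-site sketch of the new keyword, not part of the patch; `config`, `model`, `optimizer`, `lr_scheduler` and `parallel_context` are placeholders assumed to already exist, and only arguments visible in these hunks are used:

    from pathlib import Path

    from nanotron.serialize.main import load, save

    root_folder = Path("checkpoints/10")  # placeholder checkpoint folder

    # `parallel_context=...` replaces the old `dpg=...` argument.
    save(
        config=config,                      # nanotron Config (placeholder)
        model=model,                        # nn.Module (placeholder)
        optimizer=optimizer,                # optim.BaseOptimizer (placeholder)
        lr_scheduler=lr_scheduler,          # placeholder
        parallel_context=parallel_context,  # ParallelContext (placeholder)
        root_folder=root_folder,
        checkpoint_metadata={"last_train_step": 10},
    )

    checkpoint_metadata = load(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        parallel_context=parallel_context,
        root_folder=root_folder,
    )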
@@ -129,7 +131,7 @@ def save_meta(dpg: DistributedProcessGroups, root_folder: Path, checkpoint_metad json.dump(processed_metadata, fo, indent=2, sort_keys=True) -def load_meta(dpg: DistributedProcessGroups, root_folder: Path) -> CheckpointMetadata: +def load_meta(parallel_context: ParallelContext, root_folder: Path) -> CheckpointMetadata: with open(root_folder / "checkpoint_metadata.json", mode="r") as fi: checkpoint_metadata = json.load(fi) checkpoint_metadata = from_dict( diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py index dc07e53e..077ecf64 100644 --- a/src/nanotron/serialize/optimizer.py +++ b/src/nanotron/serialize/optimizer.py @@ -3,18 +3,17 @@ from typing import Optional import torch - from nanotron.core import distributed as dist from nanotron.core import optim as optim -from nanotron.core.process_groups import DistributedProcessGroups +from nanotron.distributed import ParallelContext from nanotron.serialize.utils import ObjectType -def optimizer_filename(dpg: DistributedProcessGroups, is_zero: bool): +def optimizer_filename(parallel_context: ParallelContext, is_zero: bool): if is_zero is True: - return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(dpg.pp_pg)}-of-{dpg.pp_pg.size()}_dp-{dist.get_rank(dpg.dp_pg)}-of-{dpg.dp_pg.size()}_tp-{dist.get_rank(dpg.tp_pg)}-of-{dpg.tp_pg.size()}.pt" + return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}.pt" else: - return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(dpg.pp_pg)}-of-{dpg.pp_pg.size()}_tp-{dist.get_rank(dpg.tp_pg)}-of-{dpg.tp_pg.size()}.pt" + return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}.pt" def lr_scheduler_filename(): @@ -24,7 +23,7 @@ def lr_scheduler_filename(): def save_optimizer( optimizer: optim.BaseOptimizer, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, root_folder: Path, ): """Saves optimizer states @@ -36,28 +35,29 @@ def save_optimizer( root_folder = root_folder / "optimizer" root_folder.mkdir(exist_ok=True, parents=True) - if dist.get_rank(dpg.world_pg) == 0: + if dist.get_rank(parallel_context.world_pg) == 0: with open(root_folder / "optimizer_config.json", "w") as fo: json.dump({"type": optimizer.__class__.__name__}, fo) - if (not optimizer.inherit_from(optim.ZeroDistributedOptimizer)) and dist.get_rank(dpg.dp_pg) > 0: + if (not optimizer.inherit_from(optim.ZeroDistributedOptimizer)) and dist.get_rank(parallel_context.dp_pg) > 0: # this is Zero-0, so only DP-0 saves the optimizer states return # We dump the optimizer state using `torch.save` torch.save( optimizer.state_dict(), - root_folder / optimizer_filename(dpg, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)), + root_folder + / optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)), ) def save_lr_scheduler( lr_scheduler, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, root_folder: Path, ): """Saves lr scheduler states""" - if dist.get_rank(dpg.world_pg) > 0: + if dist.get_rank(parallel_context.world_pg) > 0: # Only WORLD-RANK 0 saves the lr scheduler state return @@ -73,7 +73,7 @@ def save_lr_scheduler( def load_optimizer( optimizer: 
optim.BaseOptimizer, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, root_folder: Path, map_location: Optional[str] = None, ): @@ -83,7 +83,8 @@ def load_optimizer( # TODO @thomasw21: Load optimizer type and check that it's compatible otherwise we might be be loading something else completely state_dict = torch.load( - root_folder / optimizer_filename(dpg, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)), + root_folder + / optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)), map_location=map_location, ) optimizer.load_state_dict(state_dict) diff --git a/src/nanotron/serialize/random.py b/src/nanotron/serialize/random.py index 80b23322..015c5e79 100644 --- a/src/nanotron/serialize/random.py +++ b/src/nanotron/serialize/random.py @@ -1,22 +1,21 @@ from pathlib import Path import torch - from nanotron.core import distributed as dist -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.random import RandomStates +from nanotron.distributed import ParallelContext def save_random_states( random_states: RandomStates, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, root_folder: Path, ): """All processes save their own random state""" filename = ( root_folder / "random" - / f"tp-{dist.get_rank(dpg.tp_pg)}-of-{dpg.tp_pg.size()}_dp-{dist.get_rank(dpg.dp_pg)}-of-{dpg.dp_pg.size()}_pp-{dist.get_rank(dpg.pp_pg)}-of-{dpg.pp_pg.size()}.pt" + / f"tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}.pt" ) filename.parent.mkdir(exist_ok=True, parents=True) @@ -24,12 +23,12 @@ def save_random_states( torch.save(random_states, filename) -def load_random_states(dpg: DistributedProcessGroups, root_folder: Path): +def load_random_states(parallel_context: ParallelContext, root_folder: Path): # TODO @thomasw21: This basically assumes that we have exactly the same topology as the one we used when saving. 
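The optimizer and random-state hunks above keep the same shard-naming scheme but now read every rank and group size from the ParallelContext. A small sketch of that naming pattern follows, assuming an initialized `parallel_context`; the helper names are hypothetical and only illustrate the string layout that `save_random_states`/`load_random_states` build inline:

    from nanotron.core import distributed as dist
    from nanotron.distributed import ParallelContext

    def _rank_of(pg) -> str:
        # "<rank>-of-<size>" fragment used by the shard filenames above.
        return f"{dist.get_rank(pg)}-of-{pg.size()}"

    def random_states_filename(parallel_context: ParallelContext) -> str:
        # Hypothetical helper mirroring the tp/dp/pp layout used above.
        return (
            f"tp-{_rank_of(parallel_context.tp_pg)}"
            f"_dp-{_rank_of(parallel_context.dp_pg)}"
            f"_pp-{_rank_of(parallel_context.pp_pg)}.pt"
        )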
filename = ( root_folder / "random" - / f"tp-{dist.get_rank(dpg.tp_pg)}-of-{dpg.tp_pg.size()}_dp-{dist.get_rank(dpg.dp_pg)}-of-{dpg.dp_pg.size()}_pp-{dist.get_rank(dpg.pp_pg)}-of-{dpg.pp_pg.size()}.pt" + / f"tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}.pt" ) # TODO @thomasw21: That's annothing but this actually uses pickle, we might need to change that for something else diff --git a/src/nanotron/serialize/utils.py b/src/nanotron/serialize/utils.py index e8a113ee..1a27bb3d 100644 --- a/src/nanotron/serialize/utils.py +++ b/src/nanotron/serialize/utils.py @@ -1,7 +1,7 @@ from enum import Enum from typing import List, Optional, Tuple -from nanotron.core.process_groups import DistributedProcessGroups +from nanotron.distributed import ParallelContext class ObjectType(Enum): @@ -11,10 +11,10 @@ class ObjectType(Enum): def get_tp_and_pp_rank_and_size_from( - world_rank: int, dpg: DistributedProcessGroups + world_rank: int, parallel_context: ParallelContext ) -> Tuple[Tuple[int, int], Tuple[int, int]]: - result = dpg.get_3d_ranks(world_rank=world_rank) - return (result[2], dpg.tp_pg.size()), (result[0], dpg.pp_pg.size()) + result = parallel_context.get_3d_ranks(world_rank=world_rank) + return (result[2], parallel_context.tp_pg.size()), (result[0], parallel_context.pp_pg.size()) def get_path( diff --git a/src/nanotron/serialize/weights.py b/src/nanotron/serialize/weights.py index 598b82c5..1a250f20 100644 --- a/src/nanotron/serialize/weights.py +++ b/src/nanotron/serialize/weights.py @@ -3,17 +3,12 @@ import dacite import torch -from packaging.version import Version -from safetensors.torch import safe_open, save_file -from torch import nn -from tqdm import tqdm - from nanotron import logging from nanotron.constants import CHECKPOINT_VERSION from nanotron.core import distributed as dist from nanotron.core.distributed import get_global_rank from nanotron.core.parallel.parameters import NanotronParameter, ShardedInfo, SlicesPair -from nanotron.core.process_groups import DistributedProcessGroups +from nanotron.distributed import ParallelContext from nanotron.logging import log_rank from nanotron.serialize.metadata import CheckpointMetadata, TensorMetadata, TensorMetadataV2, load_meta from nanotron.serialize.utils import ( @@ -21,16 +16,20 @@ get_path, get_tp_and_pp_rank_and_size_from, ) +from packaging.version import Version +from safetensors.torch import safe_open, save_file +from torch import nn +from tqdm import tqdm logger = logging.get_logger(__name__) -def save_weights(model: nn.Module, dpg: DistributedProcessGroups, root_folder: Path): +def save_weights(model: nn.Module, parallel_context: ParallelContext, root_folder: Path): root_folder = root_folder / "model" - # We save only `dist.get_rank(dpg.dp_pg) == 0` + # We save only `dist.get_rank(parallel_context.dp_pg) == 0` # TODO @thomasw21: Figure how this works with Zero-3 - if dist.get_rank(dpg.dp_pg) != 0: + if dist.get_rank(parallel_context.dp_pg) != 0: return module_id_to_prefix = {id(module): f"{module_name}." 
for module_name, module in model.named_modules()} @@ -53,7 +52,7 @@ def save_weights(model: nn.Module, dpg: DistributedProcessGroups, root_folder: P tied_info = param.get_tied_info() base_name = tied_info.get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) group_ranks = tied_info.global_ranks - group = dpg.world_ranks_to_pg[group_ranks] + group = parallel_context.world_ranks_to_pg[group_ranks] # Only the first rank of the group of the tied weights saves weights # TODO @thomasw21: We could rotate in order to balance the load. if dist.get_rank(group) != 0: @@ -63,9 +62,10 @@ def save_weights(model: nn.Module, dpg: DistributedProcessGroups, root_folder: P if param.is_sharded: sharded_info = param.get_sharded_info() - group = dpg.world_ranks_to_pg[sharded_info.global_ranks] + group = parallel_context.world_ranks_to_pg[sharded_info.global_ranks] tp_and_pp_rank_and_size = get_tp_and_pp_rank_and_size_from( - world_rank=get_global_rank(group=group, group_rank=dist.get_rank(group)), dpg=dpg + world_rank=get_global_rank(group=group, group_rank=dist.get_rank(group)), + parallel_context=parallel_context, ) metadata = TensorMetadataV2( version=CHECKPOINT_VERSION, @@ -108,13 +108,13 @@ def read_checkpoint_version_from_shard_file(param_save_path: Path) -> Version: return checkpoint_version -def read_checkpoint_version_from_meta(dpg: DistributedProcessGroups, root_folder: Path) -> Version: - checkpoint_metadata: CheckpointMetadata = load_meta(dpg=dpg, root_folder=root_folder) +def read_checkpoint_version_from_meta(parallel_context: ParallelContext, root_folder: Path) -> Version: + checkpoint_metadata: CheckpointMetadata = load_meta(parallel_context=parallel_context, root_folder=root_folder) checkpoint_version = checkpoint_metadata.version return checkpoint_version -def get_checkpoint_version(dpg, root_folder, param_save_path: Path) -> Version: +def get_checkpoint_version(parallel_context, root_folder, param_save_path: Path) -> Version: try: checkpoint_version = read_checkpoint_version_from_shard_file(param_save_path=param_save_path) except CheckpointVersionFromShardFileException: @@ -124,7 +124,9 @@ def get_checkpoint_version(dpg, root_folder, param_save_path: Path) -> Version: level=logging.ERROR, rank=0, ) - checkpoint_version = read_checkpoint_version_from_meta(dpg=dpg, root_folder=root_folder) + checkpoint_version = read_checkpoint_version_from_meta( + parallel_context=parallel_context, root_folder=root_folder + ) return checkpoint_version @@ -207,7 +209,7 @@ def load_sharded_param_latest(param_or_buffer: torch.Tensor, sharded_info: Shard def load_weights( model: nn.Module, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, root_folder: Path, filtered_state_dict: Optional[Dict[str, Any]] = None, ): @@ -215,7 +217,7 @@ def load_weights( Args: model: model to load weights into - dpg: distributed process groups + parallel_context: distributed process groups root_folder: root folder of the checkpoint filtered_state_dict: state dict to load from (overrides model.state_dict()). 
if None, load from model.state_dict() """ @@ -229,7 +231,7 @@ def load_weights( filtered_state_dict = filtered_state_dict if filtered_state_dict is not None else model.state_dict() for name, param_or_buffer in tqdm( - filtered_state_dict.items(), disable=dist.get_rank(dpg.world_pg) != 0, desc="Loading weights" + filtered_state_dict.items(), disable=dist.get_rank(parallel_context.world_pg) != 0, desc="Loading weights" ): # `state_dict` doesn't return a Param or a buffer, just a tensors which loses some metadata try: @@ -249,14 +251,14 @@ def load_weights( if param.is_tied: # When params are tied only the first rank of tied param group stores weights (see save_weights) - group = dpg.world_ranks_to_pg[tied_info.global_ranks] + group = parallel_context.world_ranks_to_pg[tied_info.global_ranks] group_rank = 0 else: - group = dpg.world_ranks_to_pg[sharded_info.global_ranks] + group = parallel_context.world_ranks_to_pg[sharded_info.global_ranks] group_rank = dist.get_rank(group) tp_and_pp_rank_and_size = get_tp_and_pp_rank_and_size_from( - world_rank=get_global_rank(group=group, group_rank=group_rank), dpg=dpg + world_rank=get_global_rank(group=group, group_rank=group_rank), parallel_context=parallel_context ) else: tp_and_pp_rank_and_size = None @@ -294,7 +296,9 @@ def load_weights( raise ValueError(f"Could not find any shards in {path.parent}") if checkpoint_version is None: - checkpoint_version = get_checkpoint_version(dpg, root_folder, param_save_path=shards_path[0]) + checkpoint_version = get_checkpoint_version( + parallel_context, root_folder, param_save_path=shards_path[0] + ) else: current_checkpoint_version = None try: @@ -330,7 +334,7 @@ def load_weights( def get_checkpoint_paths_list( model: nn.Module, - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, root_folder: Path, only_list_folders: bool = False, only_list_current_process: bool = True, @@ -340,7 +344,7 @@ def get_checkpoint_paths_list( Args: model: model to load weights into - dpg: distributed process groups + parallel_context: distributed process groups root_folder: root folder of the checkpoint filtered_state_dict: state dict to load from (overrides model.state_dict()). 
if None, load from model.state_dict() """ @@ -354,7 +358,9 @@ def get_checkpoint_paths_list( filtered_state_dict = filtered_state_dict if filtered_state_dict is not None else model.state_dict() for name in tqdm( - filtered_state_dict.values(), disable=dist.get_rank(dpg.world_pg) != 0, desc="Listing checkpoint paths" + filtered_state_dict.values(), + disable=dist.get_rank(parallel_context.world_pg) != 0, + desc="Listing checkpoint paths", ): # `state_dict` doesn't return a Param or a buffer, just a tensors which loses some metadata try: @@ -374,14 +380,14 @@ def get_checkpoint_paths_list( if param.is_tied: # When params are tied only the first rank of tied param group stores weights (see save_weights) - group = dpg.world_ranks_to_pg[tied_info.global_ranks] + group = parallel_context.world_ranks_to_pg[tied_info.global_ranks] group_rank = 0 else: - group = dpg.world_ranks_to_pg[sharded_info.global_ranks] + group = parallel_context.world_ranks_to_pg[sharded_info.global_ranks] group_rank = dist.get_rank(group) tp_and_pp_rank_and_size = get_tp_and_pp_rank_and_size_from( - world_rank=get_global_rank(group=group, group_rank=group_rank), dpg=dpg + world_rank=get_global_rank(group=group, group_rank=group_rank), parallel_context=parallel_context ) else: tp_and_pp_rank_and_size = None diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 1370ebef..599c263a 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -12,8 +12,6 @@ import numpy as np import torch -from torch.nn.parallel import DistributedDataParallel - from nanotron import logging from nanotron.config import ( Config, @@ -42,7 +40,6 @@ sync_tied_weights_gradients, tie_parameters, ) -from nanotron.core.process_groups import DistributedProcessGroups, get_process_groups from nanotron.core.random import ( set_random_seed, ) @@ -53,6 +50,7 @@ init_on_device_and_dtype, ) from nanotron.dataloader import sanity_check_dataloader +from nanotron.distributed import ParallelContext from nanotron.helpers import ( _vocab_size_with_padding, get_profiler, @@ -73,6 +71,7 @@ save, save_random_states, ) +from torch.nn.parallel import DistributedDataParallel if int(os.environ.get("USE_FAST", 0)) == 1: # We import the fast versions @@ -120,19 +119,21 @@ def __init__(self, config_or_config_file: Union[Config, str]): ######################################## # Initialise all process groups - self.dpg = get_process_groups( - data_parallel_size=self.config.parallelism.dp, - pipeline_parallel_size=self.config.parallelism.pp, + self.parallel_context = ParallelContext( tensor_parallel_size=self.config.parallelism.tp, + pipeline_parallel_size=self.config.parallelism.pp, + data_parallel_size=self.config.parallelism.dp, ) # Set log levels - if dist.get_rank(self.dpg.world_pg) == 0: + if dist.get_rank(self.parallel_context.world_pg) == 0: if self.config.logging.log_level is not None: - set_logger_verbosity_format(self.config.logging.log_level, dpg=self.dpg) + set_logger_verbosity_format(self.config.logging.log_level, parallel_context=self.parallel_context) else: if self.config.logging.log_level_replica is not None: - set_logger_verbosity_format(self.config.logging.log_level_replica, dpg=self.dpg) + set_logger_verbosity_format( + self.config.logging.log_level_replica, parallel_context=self.parallel_context + ) ######################################## ## Do a couple of NCCL and CUDA tests to catch faulty nodes @@ -140,21 +141,21 @@ def __init__(self, config_or_config_file: Union[Config, str]): # Do a first NCCL sync to warmup and try to avoid 
Timeout after model/data loading log_rank( - f"[TEST] Running NCCL sync for ranks {list(range(self.dpg.world_pg.size()))}", + f"[TEST] Running NCCL sync for ranks {list(range(self.parallel_context.world_pg.size()))}", logger=logger, level=logging.WARNING, - group=self.dpg.dp_pg, + group=self.parallel_context.dp_pg, rank=0, ) - test_tensor = torch.tensor([dist.get_rank(self.dpg.world_pg)], device=torch.device("cuda")) - test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(self.dpg.world_pg.size())] - dist.all_gather(test_tensor_list, test_tensor, group=self.dpg.world_pg, async_op=False) + test_tensor = torch.tensor([dist.get_rank(self.parallel_context.world_pg)], device=torch.device("cuda")) + test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(self.parallel_context.world_pg.size())] + dist.all_gather(test_tensor_list, test_tensor, group=self.parallel_context.world_pg, async_op=False) dist.barrier() log_rank( f"[TEST] NCCL sync for ranks {[t.item() for t in test_tensor_list]}", logger=logger, level=logging.WARNING, - group=self.dpg.dp_pg, + group=self.parallel_context.dp_pg, rank=0, ) @@ -166,7 +167,7 @@ def __init__(self, config_or_config_file: Union[Config, str]): f"[TEST] free memory free_mem: {human_format(free_mem)}, total_mem: {human_format(total_mem)}", logger=logger, level=logging.WARNING, - group=self.dpg.world_pg, + group=self.parallel_context.world_pg, rank=None, ) if free_mem < MIN_GPU_MEM_THRESHOLD: @@ -178,7 +179,7 @@ def __init__(self, config_or_config_file: Union[Config, str]): f"[TEST] Allocated a tensor of size {human_format(test_tensor_size)} (90% of free memory)", logger=logger, level=logging.WARNING, - group=self.dpg.world_pg, + group=self.parallel_context.world_pg, rank=None, ) del test_tensor @@ -188,7 +189,7 @@ def __init__(self, config_or_config_file: Union[Config, str]): # Log benchmark info if os.environ.get("NANOTRON_BENCHMARK", "0") == "1": - log_throughput(self.config, self.dpg) + log_throughput(self.config, self.parallel_context) ######################################## ## Setting up our model, optimizers, schedulers, etc. 
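The trainer hunks above swap `get_process_groups(...)` for a directly constructed ParallelContext and route the NCCL warm-up through its `world_pg`. A condensed sketch of that initialization, assuming the process was launched with torchrun; the parallel sizes are illustrative stand-ins for `config.parallelism.tp/.pp/.dp`:

    import torch

    from nanotron.core import distributed as dist
    from nanotron.distributed import ParallelContext

    parallel_context = ParallelContext(
        tensor_parallel_size=2,    # illustrative; the trainer uses config.parallelism.tp
        pipeline_parallel_size=1,  # config.parallelism.pp
        data_parallel_size=4,      # config.parallelism.dp
    )

    # Same warm-up pattern as in the trainer: all-gather each global rank over
    # world_pg to catch faulty nodes before any heavy model/data loading.
    test_tensor = torch.tensor(
        [dist.get_rank(parallel_context.world_pg)], device=torch.device("cuda")
    )
    test_tensor_list = [
        torch.zeros_like(test_tensor) for _ in range(parallel_context.world_pg.size())
    ]
    dist.all_gather(test_tensor_list, test_tensor, group=parallel_context.world_pg, async_op=False)
    dist.barrier()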
@@ -198,16 +199,20 @@ def __init__(self, config_or_config_file: Union[Config, str]): set_random_seed(self.config.general.seed) # Init model and build on pp ranks - self.random_states = init_random_states(parallel_config=self.config.parallelism, tp_pg=self.dpg.tp_pg) + self.random_states = init_random_states( + parallel_config=self.config.parallelism, tp_pg=self.parallel_context.tp_pg + ) self.model, checkpoint_path = self.init_model() # Defines self.model self.normalized_model = self.model.module if isinstance(self.model, DistributedDataParallel) else self.model # Init optimizer self.optimizer, self.grad_accumulator = init_optimizer_and_grad_accumulator( - model=self.model, optimizer_args=self.config.optimizer, dpg=self.dpg + model=self.model, optimizer_args=self.config.optimizer, parallel_context=self.parallel_context ) if checkpoint_path is not None: - load_optimizer(optimizer=self.optimizer, dpg=self.dpg, root_folder=checkpoint_path) + load_optimizer( + optimizer=self.optimizer, parallel_context=self.parallel_context, root_folder=checkpoint_path + ) # Init learning rate scheduler self.lr_scheduler = lr_scheduler_builder( @@ -225,7 +230,7 @@ def __init__(self, config_or_config_file: Union[Config, str]): self.start_iteration_step: int self.consumed_train_samples: int if checkpoint_path is not None: - checkpoint_metadata = load_meta(dpg=self.dpg, root_folder=checkpoint_path) + checkpoint_metadata = load_meta(parallel_context=self.parallel_context, root_folder=checkpoint_path) log_rank(str(checkpoint_metadata), logger=logger, level=logging.INFO, rank=0) self.start_iteration_step = checkpoint_metadata.metas["last_train_step"] self.consumed_train_samples = checkpoint_metadata.metas["consumed_train_samples"] @@ -237,11 +242,13 @@ def __init__(self, config_or_config_file: Union[Config, str]): self.consumed_train_samples = 0 # Setup tensorboard write and log writers on output rank - self.logger_ranks = self.dpg.world_rank_matrix[self.normalized_model.output_pp_rank, 0, 0].flatten() + self.logger_ranks = self.parallel_context.world_rank_matrix[ + self.normalized_model.output_pp_rank, 0, 0 + ].flatten() self.loggerwriter = self.setup_log_writers() # Log where each module is instantiated - self.normalized_model.log_modules(level=logging.DEBUG, group=self.dpg.world_pg, rank=0) + self.normalized_model.log_modules(level=logging.DEBUG, group=self.parallel_context.world_pg, rank=0) # Log config and model config # self.log_object(self.config, "config") @@ -292,26 +299,28 @@ def __init__(self, config_or_config_file: Union[Config, str]): # self.log_object(slurm_dict, "slurm") # Do a first NCCL sync to warmup and try to avoid Timeout after model/data loading - test_tensor = torch.tensor([dist.get_rank(self.dpg.world_pg)], device=torch.device("cuda")) - test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(self.dpg.world_pg.size())] - dist.all_gather(test_tensor_list, test_tensor, group=self.dpg.world_pg, async_op=False) + test_tensor = torch.tensor([dist.get_rank(self.parallel_context.world_pg)], device=torch.device("cuda")) + test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(self.parallel_context.world_pg.size())] + dist.all_gather(test_tensor_list, test_tensor, group=self.parallel_context.world_pg, async_op=False) dist.barrier() log_rank( f"[SECOND TEST] NCCL sync for ranks {[t.item() for t in test_tensor_list]}", logger=logger, level=logging.WARNING, - group=self.dpg.dp_pg, + group=self.parallel_context.dp_pg, rank=0, ) log_rank( - f"Global rank: { 
dist.get_rank(self.dpg.world_pg)}/{self.dpg.world_pg.size()} | PP: {dist.get_rank(self.dpg.pp_pg)}/{self.dpg.pp_pg.size()} | DP: {dist.get_rank(self.dpg.dp_pg)}/{self.dpg.dp_pg.size()} | TP: {dist.get_rank(self.dpg.tp_pg)}/{self.dpg.tp_pg.size()}", + f"Global rank: { dist.get_rank(self.parallel_context.world_pg)}/{self.parallel_context.world_pg.size()} | PP: {dist.get_rank(self.parallel_context.pp_pg)}/{self.parallel_context.pp_pg.size()} | DP: {dist.get_rank(self.parallel_context.dp_pg)}/{self.parallel_context.dp_pg.size()} | TP: {dist.get_rank(self.parallel_context.tp_pg)}/{self.parallel_context.tp_pg.size()}", logger=logger, level=logging.INFO, ) self.micro_batch_size = self.config.tokens.micro_batch_size self.n_micro_batches_per_batch = self.config.tokens.batch_accumulation_per_replica - self.global_batch_size = self.micro_batch_size * self.n_micro_batches_per_batch * self.dpg.dp_pg.size() + self.global_batch_size = ( + self.micro_batch_size * self.n_micro_batches_per_batch * self.parallel_context.dp_pg.size() + ) self.sequence_length = self.config.tokens.sequence_length self.iteration_step = self.start_iteration_step self.limit_val_batches = self.config.tokens.limit_val_batches @@ -332,11 +341,11 @@ def __init__(self, config_or_config_file: Union[Config, str]): # ) # else: # self.s3_mover = None - # if self.config.checkpoints.lighteval is not None and dist.get_rank(self.dpg.world_pg) == 0: + # if self.config.checkpoints.lighteval is not None and dist.get_rank(self.parallel_context.world_pg) == 0: # # We only start evaluation runs on the first node # if self.s3_mover is None: # raise ValueError("lighteval requires s3 upload of checkpoints to be enabled") - # self.lighteval_runner = LightEvalRunner(config=self.config, dpg=self.dpg) + # self.lighteval_runner = LightEvalRunner(config=self.config, parallel_context=self.parallel_context) # self.s3_mover.post_upload_callback = self.lighteval_runner.eval_single_checkpoint if self.config.checkpoints.save_initial_state and checkpoint_path is None: @@ -370,7 +379,9 @@ def train( dataloader = dataloader_or_dls[0] else: dataloader = dataloader_or_dls - dataloader = sanity_check_dataloader(dataloader=dataloader, dpg=self.dpg, config=self.config) + dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) # Log data config # self.log_object(data_config_log, name="data_config") @@ -448,7 +459,7 @@ def train( # self.tb_context.scheduler.trigger() # if self.s3_mover is not None: - # self.s3_mover.distributed_wait_for_completion(group=self.dpg.world_pg) + # self.s3_mover.distributed_wait_for_completion(group=self.parallel_context.world_pg) def training_step( self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]] @@ -462,14 +473,14 @@ def training_step( f" Peak reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f}MiB", logger=logger, level=logging.INFO, - group=self.dpg.world_pg, + group=self.parallel_context.world_pg, rank=0, ) torch.cuda.reset_peak_memory_stats() outputs = self.pipeline_engine.train_batch_iter( model=self.model, - pg=self.dpg.pp_pg, + pg=self.parallel_context.pp_pg, batch=(next(dataloader) for _ in range(self.n_micro_batches_per_batch)), nb_microbatches=self.n_micro_batches_per_batch, grad_accumulator=self.grad_accumulator, @@ -482,7 +493,7 @@ def training_step( f" Peak reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f}MiB", logger=logger, level=logging.INFO, - group=self.dpg.world_pg, + group=self.parallel_context.world_pg, rank=0, ) 
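Earlier in this trainer hunk, `global_batch_size` is recomputed from the ParallelContext's data-parallel group size. A worked example with illustrative numbers (not taken from any config in this patch):

    micro_batch_size = 4                # config.tokens.micro_batch_size
    batch_accumulation_per_replica = 8  # config.tokens.batch_accumulation_per_replica
    dp_size = 4                         # parallel_context.dp_pg.size()

    # Samples consumed per optimizer step, as computed in the trainer above.
    global_batch_size = micro_batch_size * batch_accumulation_per_replica * dp_size
    assert global_batch_size == 128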
torch.cuda.reset_peak_memory_stats() @@ -501,7 +512,7 @@ def training_step( # Manually sync across DP if it's not handled by DDP sync_gradients_across_dp( module=self.model, - dp_pg=self.dpg.dp_pg, + dp_pg=self.parallel_context.dp_pg, reduce_op=dist.ReduceOp.AVG, # TODO @thomasw21: This is too memory hungry, instead we run all_reduce reduce_scatter=False, # optimizer.inherit_from(ZeroDistributedOptimizer), @@ -511,7 +522,7 @@ def training_step( # TODO @nouamane: Put this in hooks so we can overlap communication with gradient computation on the last backward pass. sync_tied_weights_gradients( module=self.normalized_model, - dpg=self.dpg, + parallel_context=self.parallel_context, grad_accumulator=self.grad_accumulator, ) @@ -532,8 +543,14 @@ def training_step( ] # TODO @nouamane: we need to split `world_rank_matrix` along PP axis, to separate ref from active model self.grad_norm_unclipped = clip_grad_norm( - mp_pg=self.dpg.world_ranks_to_pg[ - tuple(sorted(self.dpg.world_rank_matrix[:, dist.get_rank(self.dpg.dp_pg), :].reshape(-1))) + mp_pg=self.parallel_context.world_ranks_to_pg[ + tuple( + sorted( + self.parallel_context.world_rank_matrix[ + :, dist.get_rank(self.parallel_context.dp_pg), : + ].reshape(-1) + ) + ) ], named_parameters=named_parameters, grad_accumulator=self.grad_accumulator, @@ -549,7 +566,7 @@ def training_step( [output["loss"] for output in outputs] ).sum() # already divided by n_micro_batches_per_batch # sync loss across DP - handle = dist.all_reduce(loss_avg, group=self.dpg.dp_pg, async_op=True, op=dist.ReduceOp.AVG) + handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG) else: loss_avg = None handle = None @@ -593,7 +610,7 @@ def train_step_logs( global_batch_size=self.global_batch_size, ) - if dist.get_rank(self.dpg.world_pg) in self.logger_ranks: + if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks: assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" lr = self.lr_scheduler.get_last_lr()[0] @@ -606,7 +623,7 @@ def train_step_logs( LogItem("elapsed_time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format"), # , ".1f"), LogItem("tokens_per_sec", tokens_per_sec, "human_format"), # , "1.6E"), LogItem( - "tokens_per_sec_per_gpu", tokens_per_sec / self.dpg.world_pg.size(), "human_format" + "tokens_per_sec_per_gpu", tokens_per_sec / self.parallel_context.world_pg.size(), "human_format" ), # , "1.6E"), LogItem("global_batch_size", self.global_batch_size, "human_format"), # , "5d"), LogItem("lm_loss", loss_avg.item(), "human_format"), # , "1.6E"), @@ -649,7 +666,7 @@ def train_step_logs( if os.environ.get("NANOTRON_BENCHMARK", "0") == "1" and self.iteration_step == 3: log_throughput( self.config, - self.dpg, + self.parallel_context, model_tflops, hardware_tflops, tokens_per_sec, @@ -663,25 +680,27 @@ def train_step_logs( @staticmethod def build_model( model_builder: Callable[[], NanotronModel], - dpg: DistributedProcessGroups, + parallel_context: ParallelContext, dtype: torch.dtype, target_pp_ranks: Optional[List[int]] = None, device: Optional[torch.device] = torch.device("cuda"), ) -> NanotronModel: """Build the model and set the pp ranks for each pipeline block.""" # TODO: classes dont take same args - log_rank("Building model..", logger=logger, level=logging.INFO, rank=0, group=dpg.world_pg) + log_rank("Building model..", logger=logger, level=logging.INFO, rank=0, group=parallel_context.world_pg) model: NanotronModel = model_builder() # If no target pp ranks 
are specified, we assume that we want to use all pp ranks if target_pp_ranks is None: - pp_size = dpg.pp_pg.size() + pp_size = parallel_context.pp_pg.size() target_pp_ranks = list(range(pp_size)) else: pp_size = len(target_pp_ranks) # Set rank for each pipeline block - log_rank("Setting PP block ranks..", logger=logger, level=logging.INFO, rank=0, group=dpg.world_pg) + log_rank( + "Setting PP block ranks..", logger=logger, level=logging.INFO, rank=0, group=parallel_context.world_pg + ) pipeline_blocks = [module for name, module in model.named_modules() if isinstance(module, PipelineBlock)] # "cuda" is already defaulted for each process to it's own cuda device with init_on_device_and_dtype(device=device, dtype=dtype): @@ -716,7 +735,7 @@ def init_model(self) -> Tuple[NanotronModel, Optional[str]]: # TODO: add max_position_embeddings self.model_config.vocab_size = _vocab_size_with_padding( self.model_config.vocab_size, - pg_size=self.dpg.tp_pg.size(), + pg_size=self.parallel_context.tp_pg.size(), make_vocab_size_divisible_by=self.config.model.make_vocab_size_divisible_by, ) @@ -752,7 +771,7 @@ def init_model(self) -> Tuple[NanotronModel, Optional[str]]: model = self._init_model( model_builder=lambda: CONFIG_TO_MODEL_CLASS[model_config_cls]( config=self.model_config, - dpg=self.dpg, + parallel_context=self.parallel_context, parallel_config=self.config.parallelism, random_states=self.random_states, ), @@ -765,13 +784,17 @@ def init_model(self) -> Tuple[NanotronModel, Optional[str]]: if checkpoint_path is not None: # Reload from a training checkpoint log_rank(f"Loading weights from {checkpoint_path}", logger=logger, level=logging.INFO, rank=0) - load_weights(model=normalized_model, dpg=self.dpg, root_folder=checkpoint_path) + load_weights(model=normalized_model, parallel_context=self.parallel_context, root_folder=checkpoint_path) reloaded_from_checkpoint = True if not reloaded_from_checkpoint: log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO) if isinstance(self.config.model.init_method, ExistingCheckpointInit): # Initialize model from an pretrained model checkpoint - load_weights(model=normalized_model, dpg=self.dpg, root_folder=self.config.model.init_method.path) + load_weights( + model=normalized_model, + parallel_context=self.parallel_context, + root_folder=self.config.model.init_method.path, + ) elif isinstance(self.config.model.init_method, RandomInit): # Initialize model randomly normalized_model.init_model_randomly( @@ -783,7 +806,7 @@ def init_model(self) -> Tuple[NanotronModel, Optional[str]]: # Synchronize parameters so that the model is consistent # sync all params across dp for name, param in sorted(model.named_parameters(), key=lambda x: x[0]): - dist.all_reduce(param, op=dist.ReduceOp.AVG, group=self.dpg.dp_pg) + dist.all_reduce(param, op=dist.ReduceOp.AVG, group=self.parallel_context.dp_pg) # sync tied params across tied groups for (_, group_ranks), param in sorted( @@ -793,7 +816,7 @@ def init_model(self) -> Tuple[NanotronModel, Optional[str]]: ).items(), key=lambda x: x[0], ): - group = self.dpg.world_ranks_to_pg[group_ranks] + group = self.parallel_context.world_ranks_to_pg[group_ranks] dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group) else: raise ValueError(f"Unsupported {self.config.model.init_method}") @@ -806,14 +829,14 @@ def _init_model( target_pp_ranks: Optional[List[int]] = None, ) -> NanotronModel: config = self.config - dpg = self.dpg + parallel_context = self.parallel_context parallel_config = config.parallelism make_ddp = not 
(config.optimizer.accumulate_grad_in_fp32 and config.optimizer.zero_stage > 0) # Build model and set pp ranks model = self.build_model( - dpg=dpg, + parallel_context=parallel_context, dtype=config.model.dtype, target_pp_ranks=target_pp_ranks, model_builder=model_builder, @@ -826,31 +849,31 @@ def _init_model( module.init_rotary_embeddings() # Mark some parameters as tied - mark_tied_parameters(model=model, dpg=dpg, parallel_config=parallel_config) + mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) # count number of parameters num_params = sum(p.numel() for p in model.parameters()) size_params = sum(p.numel() * p.element_size() for p in model.parameters()) total_params = torch.tensor(num_params, device="cuda") total_size = torch.tensor(size_params, device="cuda") - dist.all_reduce(total_params, group=dpg.tp_pg, async_op=False, op=dist.ReduceOp.SUM) # TP - dist.all_reduce(total_params, group=dpg.pp_pg, async_op=False, op=dist.ReduceOp.SUM) # PP - dist.all_reduce(total_size, group=dpg.tp_pg, async_op=False, op=dist.ReduceOp.SUM) - dist.all_reduce(total_size, group=dpg.pp_pg, async_op=False, op=dist.ReduceOp.SUM) + dist.all_reduce(total_params, group=parallel_context.tp_pg, async_op=False, op=dist.ReduceOp.SUM) # TP + dist.all_reduce(total_params, group=parallel_context.pp_pg, async_op=False, op=dist.ReduceOp.SUM) # PP + dist.all_reduce(total_size, group=parallel_context.tp_pg, async_op=False, op=dist.ReduceOp.SUM) + dist.all_reduce(total_size, group=parallel_context.pp_pg, async_op=False, op=dist.ReduceOp.SUM) # TODO @nouamanetazi: better memory logs log_rank( f"Total number of parameters: {human_format(total_params.item())} ({total_size.item() / 1024**2:.2f}MiB)", logger=logger, level=logging.INFO, - group=dpg.world_pg, + group=parallel_context.world_pg, rank=0, ) log_rank( f"Local number of parameters: {human_format(num_params)} ({size_params / 1024**2:.2f}MiB)", logger=logger, level=logging.INFO, - group=dpg.dp_pg, + group=parallel_context.dp_pg, rank=0, ) log_rank( @@ -859,7 +882,7 @@ def _init_model( f" Peak reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f}MiB", logger=logger, level=logging.INFO, - group=dpg.dp_pg, + group=parallel_context.dp_pg, rank=0, ) @@ -867,7 +890,10 @@ def _init_model( if make_ddp is True: # TODO @thomasw21: DDP doesn't support broadcasting complex buffers (and we don't really need that broadcasting anyway) model = DistributedDataParallel( - model, process_group=dpg.dp_pg, broadcast_buffers=False, bucket_cap_mb=config.model.ddp_bucket_cap_mb + model, + process_group=parallel_context.dp_pg, + broadcast_buffers=False, + bucket_cap_mb=config.model.ddp_bucket_cap_mb, ) # Sanity check the model, all parameters must be NanotronParameter (either tied or sharded) @@ -883,9 +909,9 @@ def setup_log_writers( Args: config (Config): The config object logger_ranks (Iterable[int]): The ranks that should log - dpg (DistributedProcessGroups): The distributed process groups + parallel_context (DistributedProcessGroups): The distributed process groups """ - if dist.get_rank(self.dpg.world_pg) in self.logger_ranks: + if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks: loggerwriter = LoggerWriter(global_step=self.config.tokens.train_steps) else: loggerwriter = None @@ -909,7 +935,7 @@ def check_kill_switch(self, save_ckpt: bool): def save_checkpoint(self) -> Path: # if self.s3_mover is not None: - # self.s3_mover.distributed_wait_for_completion(self.dpg.world_pg) + # 
self.s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg) # if self.s3_mover.post_upload_callback_outputs is not None: # slurm_job_id, slurm_log = self.s3_mover.post_upload_callback_outputs # self.log_object({"job_id": slurm_job_id, "log": slurm_log}, "slurm_eval") @@ -917,12 +943,12 @@ def save_checkpoint(self) -> Path: checkpoints_path = self.config.checkpoints.checkpoints_path checkpoint_path = checkpoints_path / f"{self.iteration_step}" if self.config.checkpoints.checkpoints_path_is_shared_file_system: - should_mkdir = dist.get_rank(self.dpg.world_pg) == 0 + should_mkdir = dist.get_rank(self.parallel_context.world_pg) == 0 else: should_mkdir = bool(int(os.environ.get("LOCAL_RANK", None)) == 0) if should_mkdir: checkpoint_path.mkdir(parents=True, exist_ok=True) - dist.barrier(self.dpg.world_pg) + dist.barrier(self.parallel_context.world_pg) log_rank(f"Saving checkpoint at {checkpoint_path}", logger=logger, level=logging.WARNING, rank=0) checkpoint_metadata = { @@ -939,18 +965,24 @@ def save_checkpoint(self) -> Path: model=self.normalized_model, optimizer=self.optimizer, lr_scheduler=self.lr_scheduler, - should_save_model=bool(dist.get_rank(self.dpg.dp_pg) == 0), # We only save the weights on DP==0 + should_save_model=bool( + dist.get_rank(self.parallel_context.dp_pg) == 0 + ), # We only save the weights on DP==0 should_save_optimizer=True, should_save_lr_scheduler=bool( - dist.get_rank(self.dpg.world_pg) == 0 + dist.get_rank(self.parallel_context.world_pg) == 0 ), # We only save the lr_scheduler on world_rank==0 - should_save_config=bool(dist.get_rank(self.dpg.world_pg) == 0), # We only save the config on world_rank==0 - dpg=self.dpg, + should_save_config=bool( + dist.get_rank(self.parallel_context.world_pg) == 0 + ), # We only save the config on world_rank==0 + parallel_context=self.parallel_context, root_folder=checkpoint_path, checkpoint_metadata=checkpoint_metadata, config=self.config, ) - save_random_states(random_states=self.random_states, dpg=self.dpg, root_folder=checkpoint_path) + save_random_states( + random_states=self.random_states, parallel_context=self.parallel_context, root_folder=checkpoint_path + ) with open(checkpoints_path / "latest.txt", mode="w") as fo: fo.write(f"{self.iteration_step}") @@ -971,7 +1003,9 @@ def before_tbi_sanity_checks(self) -> None: # SANITY CHECK: Check that the model params are synchronized across dp for name, param in sorted(self.model.named_parameters(), key=lambda x: x[0]): assert_tensor_synced_across_pg( - tensor=param, pg=self.dpg.dp_pg, msg=lambda err: f"{name} are not synchronized across DP {err}" + tensor=param, + pg=self.parallel_context.dp_pg, + msg=lambda err: f"{name} are not synchronized across DP {err}", ) # SANITY CHECK: Tied weights are synchronized @@ -983,7 +1017,7 @@ def before_tbi_sanity_checks(self) -> None: key=lambda x: x[0], ) for (name, group_ranks), param in tied_params_list: - group = self.dpg.world_ranks_to_pg[group_ranks] + group = self.parallel_context.world_ranks_to_pg[group_ranks] assert_tensor_synced_across_pg( tensor=param, pg=group, @@ -1029,7 +1063,7 @@ def after_tbi_sanity_checks(self) -> None: raise ValueError("Gradient is nan or inf") if grad is None: log_rank( - f"Process rank { dist.get_rank(self.dpg.world_pg)}/{self.dpg.world_pg.size()}: {name} is missing gradient", + f"Process rank { dist.get_rank(self.parallel_context.world_pg)}/{self.parallel_context.world_pg.size()}: {name} is missing gradient", logger=logger, level=logging.ERROR, ) @@ -1055,7 +1089,7 @@ def 
before_optim_step_sanity_checks(self) -> None: grad = param.grad assert grad is not None, f"Grad is None for {name}" - group = self.dpg.world_ranks_to_pg[group_ranks] + group = self.parallel_context.world_ranks_to_pg[group_ranks] assert_tensor_synced_across_pg( tensor=grad, pg=group, @@ -1081,14 +1115,16 @@ def before_optim_step_sanity_checks(self) -> None: assert grad is not None, f"Grad is None for {name}" assert_tensor_synced_across_pg( tensor=grad, - pg=self.dpg.dp_pg, + pg=self.parallel_context.dp_pg, msg=lambda err: f"[Before optimizer step] weights grads for {name} are not synchronized across DP. {err}", ) # SANITY CHECK: Check that the model params are synchronized across dp for name, param in sorted(self.model.named_parameters(), key=lambda x: x[0]): assert_tensor_synced_across_pg( - tensor=param, pg=self.dpg.dp_pg, msg=lambda err: f"{name} are not synchronized across DP {err}" + tensor=param, + pg=self.parallel_context.dp_pg, + msg=lambda err: f"{name} are not synchronized across DP {err}", ) # SANITY CHECK: Tied weights are synchronized @@ -1100,7 +1136,7 @@ def before_optim_step_sanity_checks(self) -> None: ) for (name, group_ranks), param in tied_params_list: - group = self.dpg.world_ranks_to_pg[group_ranks] + group = self.parallel_context.world_ranks_to_pg[group_ranks] assert_tensor_synced_across_pg( tensor=param, pg=group, @@ -1119,7 +1155,7 @@ def after_optim_step_sanity_checks(self) -> None: if param.grad is not None: log_rank( - f"Process rank { dist.get_rank(self.dpg.world_pg)}/{self.dpg.world_pg.size()}: {name} still has gradient despite having ran the optimizer", + f"Process rank { dist.get_rank(self.parallel_context.world_pg)}/{self.parallel_context.world_pg.size()}: {name} still has gradient despite having ran the optimizer", logger=logger, level=logging.ERROR, ) @@ -1129,7 +1165,7 @@ def after_optim_step_sanity_checks(self) -> None: def mark_tied_parameters( - model: NanotronModel, dpg: DistributedProcessGroups, parallel_config: Optional[ParallelismArgs] = None + model: NanotronModel, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs] = None ): # Tie embeddings embeddings_lm_head_tied_names = model.get_embeddings_lm_head_tied_names() @@ -1138,14 +1174,18 @@ def mark_tied_parameters( ( target, ( - dpg.world_rank_matrix[ - get_pp_rank_of(target, module=model), dist.get_rank(dpg.dp_pg), dist.get_rank(dpg.tp_pg) + parallel_context.world_rank_matrix[ + get_pp_rank_of(target, module=model), + dist.get_rank(parallel_context.dp_pg), + dist.get_rank(parallel_context.tp_pg), ], ), ) for target in embeddings_lm_head_tied_names ] - tie_parameters(root_module=model, ties=shared_embeddings, dpg=dpg, reduce_op=dist.ReduceOp.SUM) + tie_parameters( + root_module=model, ties=shared_embeddings, parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM + ) # Sync all parameters that have the same name and that are not sharded assert not isinstance(model, DistributedDataParallel), "model shouldn't be DDP at this point" @@ -1174,7 +1214,13 @@ def mark_tied_parameters( ( name, # This adds all the tp_ranks in one go - tuple(sorted(dpg.world_rank_matrix[dist.get_rank(dpg.pp_pg), dist.get_rank(dpg.dp_pg), :])), + tuple( + sorted( + parallel_context.world_rank_matrix[ + dist.get_rank(parallel_context.pp_pg), dist.get_rank(parallel_context.dp_pg), : + ] + ) + ), ) ] @@ -1185,6 +1231,8 @@ def mark_tied_parameters( else: reduce_op = dist.ReduceOp.SUM - tie_parameters(root_module=model, ties=shared_weights, dpg=dpg, reduce_op=reduce_op) + tie_parameters( + 
root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=reduce_op + ) - create_pg_for_tied_weights(root_module=model, dpg=dpg) + create_pg_for_tied_weights(root_module=model, parallel_context=parallel_context) diff --git a/tests/helpers/dummy.py b/tests/helpers/dummy.py index dcd904d8..fe3dd9e5 100644 --- a/tests/helpers/dummy.py +++ b/tests/helpers/dummy.py @@ -11,8 +11,8 @@ from nanotron.core.parallel.pipeline_parallelism.p2p import P2P from nanotron.core.parallel.pipeline_parallelism.tensor_pointer import TensorPointer from nanotron.core.parallel.tied_parameters import tie_parameters -from nanotron.core.process_groups import DistributedProcessGroups from nanotron.core.utils import init_on_device_and_dtype +from nanotron.distributed import ParallelContext from torch import nn from torch.nn.parallel import DistributedDataParallel @@ -64,14 +64,14 @@ def forward(self, x: Union[torch.Tensor, TensorPointer]): return x -def init_dummy_model(dpg: DistributedProcessGroups, dtype: torch.dtype = torch.float) -> DummyModel: - p2p = P2P(pg=dpg.pp_pg, device=torch.device("cuda")) +def init_dummy_model(parallel_context: ParallelContext, dtype: torch.dtype = torch.float) -> DummyModel: + p2p = P2P(pg=parallel_context.pp_pg, device=torch.device("cuda")) model = DummyModel(p2p=p2p) # Build model using contiguous segments pipeline_blocks = [module for name, module in model.named_modules() if isinstance(module, PipelineBlock)] with init_on_device_and_dtype(device=torch.device("cuda"), dtype=dtype): - contiguous_size = ceil(len(pipeline_blocks) / dpg.pp_pg.size()) + contiguous_size = ceil(len(pipeline_blocks) / parallel_context.pp_pg.size()) for i, block in enumerate(pipeline_blocks): rank = i // contiguous_size block.build_and_set_rank(rank) @@ -84,15 +84,21 @@ def init_dummy_model(dpg: DistributedProcessGroups, dtype: torch.dtype = torch.f ( name, # This adds all the tp_ranks in one go - set(dpg.world_rank_matrix[dist.get_rank(dpg.pp_pg), dist.get_rank(dpg.dp_pg), :]), + set( + parallel_context.world_rank_matrix[ + dist.get_rank(parallel_context.pp_pg), dist.get_rank(parallel_context.dp_pg), : + ] + ), ) ] - tie_parameters(root_module=model, ties=shared_weights, dpg=dpg, reduce_op=dist.ReduceOp.SUM) + tie_parameters( + root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM + ) - initial_sync(model=model, dpg=dpg) + initial_sync(model=model, parallel_context=parallel_context) if len(list(model.named_parameters())) > 0: - model = DistributedDataParallel(model, process_group=dpg.dp_pg) + model = DistributedDataParallel(model, process_group=parallel_context.dp_pg) else: # No parameters, so no need to use DDP to sync parameters gradients model = model @@ -100,7 +106,7 @@ def init_dummy_model(dpg: DistributedProcessGroups, dtype: torch.dtype = torch.f return model -def init_dummy_optimizer(model: nn.Module, dpg: DistributedProcessGroups) -> BaseOptimizer: +def init_dummy_optimizer(model: nn.Module, parallel_context: ParallelContext) -> BaseOptimizer: optimizer = NamedOptimizer( named_params_or_groups=model.named_parameters(), optimizer_builder=lambda params: torch.optim.AdamW(params) ) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 6ca3e81a..7b7883aa 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -72,7 +72,7 @@ def __init__(self, func, args, kwargs, tp: int, dp: int, pp: int): def __call__(self): with mock_os_environ(update_key_values={"WORLD_SIZE": f"{self.tp * self.dp * self.pp}"}): - # dpg = 
get_process_groups( + # parallel_context = get_process_groups( # data_parallel_size=self.dp, # pipeline_parallel_size=self.pp, # tensor_parallel_size=self.tp, @@ -94,7 +94,7 @@ def init_distributed(tp: int, dp: int, pp: int): def _init_distributed(func): """Wrapper to help initialize distributed nanotron. - :param func: parallel function that runs on all the process, it requires one of its keyword argument to be "dpg" + :param func: parallel function that runs on all the process, it requires one of its keyword argument to be "parallel_context" """ nb_gpus = tp * dp * pp run_id = uuid.uuid4() diff --git a/tests/pytest.ini b/tests/pytest.ini index 05aa1de7..66cfb528 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -1,2 +1,2 @@ [pytest] -addopts=-n 24 \ No newline at end of file +addopts=-n 35 diff --git a/tests/test_clip_grads.py b/tests/test_clip_grads.py index 08b020b1..02d0c335 100644 --- a/tests/test_clip_grads.py +++ b/tests/test_clip_grads.py @@ -355,13 +355,13 @@ def _test_clip_grads_tied_weights(parallel_context: ParallelContext, norm_type: tie_parameters( root_module=model, ties=[("dense0.weight", (0,)), ("dense1.weight", (1,))], - dpg=parallel_context, + parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM, ) tie_parameters( root_module=model, ties=[("dense0.bias", (0,)), ("dense1.bias", (1,))], - dpg=parallel_context, + parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM, ) @@ -382,7 +382,7 @@ def _test_clip_grads_tied_weights(parallel_context: ParallelContext, norm_type: assert bias.is_tied # Sync tied weights: basic assumption - initial_sync(model=model, dpg=parallel_context) + initial_sync(model=model, parallel_context=parallel_context) # Check that weights are now synced assert_tensor_synced_across_pg(weight, group) @@ -397,7 +397,7 @@ def _test_clip_grads_tied_weights(parallel_context: ParallelContext, norm_type: out.sum().backward() # sync gradients - sync_tied_weights_gradients(model, dpg=parallel_context, grad_accumulator=None) + sync_tied_weights_gradients(model, parallel_context=parallel_context, grad_accumulator=None) # We check that we both gradients are synchronized assert_tensor_synced_across_pg(weight.grad, group) diff --git a/tests/test_parameters_accumulate_gradient_in_fp32.py b/tests/test_parameters_accumulate_gradient_in_fp32.py index 0f4333d5..c88672f6 100644 --- a/tests/test_parameters_accumulate_gradient_in_fp32.py +++ b/tests/test_parameters_accumulate_gradient_in_fp32.py @@ -349,7 +349,7 @@ def _test_tied_weights_sync_with_grad_accum_in_fp32( f"mlp.{pp_rank}.linear.pp_block.weight" for pp_rank in range(parallel_context.pp_pg.size()) ] ], - dpg=parallel_context, + parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM, ) @@ -358,7 +358,7 @@ def _test_tied_weights_sync_with_grad_accum_in_fp32( module.bias = NanotronParameter(module.bias) # Sync DP and tied weights: basic assumption - initial_sync(model=mdl, dpg=parallel_context) + initial_sync(model=mdl, parallel_context=parallel_context) # Sync params between `model` and `reference_model` with torch.no_grad(): @@ -532,7 +532,9 @@ def forward_backward_reference(mdl, micro_batch): # - Translate tied ranks along DP axis to find the DP rank that has the tied weights # - accumulator keeps grads for all DPs, so we can just sync the grads with timeout_after(): - sync_tied_weights_gradients(module=model_ddp.module, dpg=parallel_context, grad_accumulator=accumulator) + sync_tied_weights_gradients( + module=model_ddp.module, parallel_context=parallel_context, 
grad_accumulator=accumulator + ) tied_infos_dict = { ( diff --git a/tests/test_serialize.py b/tests/test_serialize.py index c66911bd..760d1e31 100644 --- a/tests/test_serialize.py +++ b/tests/test_serialize.py @@ -55,14 +55,14 @@ def test_save_and_load_model(tp: int, dp: int, pp: int): def _test_save_and_load_model(parallel_context: ParallelContext, test_context: TestContext): - model = init_dummy_model(dpg=parallel_context) + model = init_dummy_model(parallel_context=parallel_context) store_folder = test_context.get_auto_remove_tmp_dir() # Save - save_weights(model=model, dpg=parallel_context, root_folder=store_folder) + save_weights(model=model, parallel_context=parallel_context, root_folder=store_folder) # Load - new_model = init_dummy_model(dpg=parallel_context) + new_model = init_dummy_model(parallel_context=parallel_context) # Check that the newly initialised model isn't the same. match, msg = is_dict_equal(new_model.state_dict(), model.state_dict()) @@ -72,7 +72,7 @@ def _test_save_and_load_model(parallel_context: ParallelContext, test_context: T else: assert not match, "Newly initialised model should not match." - load_weights(model=new_model, dpg=parallel_context, root_folder=store_folder) + load_weights(model=new_model, parallel_context=parallel_context, root_folder=store_folder) # Assert the weights are exactly the same after loading match, msg = is_dict_equal(new_model.state_dict(), model.state_dict()) @@ -95,7 +95,7 @@ def test_save_and_load_optimizer(tp: int, dp: int, pp: int): def _test_save_and_load_optimizer(parallel_context: ParallelContext, test_context: TestContext): store_folder = test_context.get_auto_remove_tmp_dir() - model = init_dummy_model(dpg=parallel_context) + model = init_dummy_model(parallel_context=parallel_context) optimizer = NamedOptimizer( named_params_or_groups=model.named_parameters(), optimizer_builder=lambda params: torch.optim.AdamW(params), @@ -111,13 +111,13 @@ def _test_save_and_load_optimizer(parallel_context: ParallelContext, test_contex model=model, pg=parallel_context.pp_pg, batch=[minibatch], nb_microbatches=1, grad_accumulator=None ) # Manually sync tied parameters - sync_tied_weights_gradients(module=model, dpg=parallel_context, grad_accumulator=None) + sync_tied_weights_gradients(module=model, parallel_context=parallel_context, grad_accumulator=None) # Optimizer steps optimizer.step() optimizer.zero_grad() # Save optimizer - save_optimizer(optimizer=optimizer, dpg=parallel_context, root_folder=store_folder) + save_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=store_folder) dist.barrier(parallel_context.world_pg) # Generate a new optimizer @@ -134,7 +134,7 @@ def _test_save_and_load_optimizer(parallel_context: ParallelContext, test_contex else: assert not match, "Newly initialised optimizer should not match." - load_optimizer(optimizer=new_optimizer, dpg=parallel_context, root_folder=store_folder) + load_optimizer(optimizer=new_optimizer, parallel_context=parallel_context, root_folder=store_folder) # Assert the optimizer states are exactly the same after loading. 
match, msg = is_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) @@ -157,7 +157,7 @@ def test_save_zero_optimizer_and_load_optimizer(tp: int, dp: int, pp: int): def _test_save_zero_optimizer_and_load_optimizer(parallel_context: ParallelContext, test_context: TestContext): store_folder = test_context.get_auto_remove_tmp_dir() - model = init_dummy_model(dpg=parallel_context) + model = init_dummy_model(parallel_context=parallel_context) optimizer = ZeroDistributedOptimizer( named_params_or_groups=model.named_parameters(), optimizer_builder=lambda named_param_groups: NamedOptimizer( @@ -177,13 +177,13 @@ def _test_save_zero_optimizer_and_load_optimizer(parallel_context: ParallelConte model=model, pg=parallel_context.pp_pg, batch=[minibatch], nb_microbatches=1, grad_accumulator=None ) # Manually sync tied parameters - sync_tied_weights_gradients(module=model, dpg=parallel_context, grad_accumulator=None) + sync_tied_weights_gradients(module=model, parallel_context=parallel_context, grad_accumulator=None) # Optimizer steps optimizer.step() optimizer.zero_grad() # Save optimizer - save_optimizer(optimizer=optimizer, dpg=parallel_context, root_folder=store_folder) + save_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=store_folder) dist.barrier(parallel_context.world_pg) # Generate a new optimizer @@ -204,7 +204,7 @@ def _test_save_zero_optimizer_and_load_optimizer(parallel_context: ParallelConte else: assert not match, "Newly initialised optimizer should not match." - load_optimizer(optimizer=new_optimizer, dpg=parallel_context, root_folder=store_folder) + load_optimizer(optimizer=new_optimizer, parallel_context=parallel_context, root_folder=store_folder) # Assert the optimizer states are exactly the same after loading. match, msg = is_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) @@ -232,7 +232,7 @@ def _test_save_zero_optimizer_and_load_data_parallel_optimizer( parallel_context: ParallelContext, test_context: TestContext ): store_folder = test_context.get_auto_remove_tmp_dir() - model = init_dummy_model(dpg=parallel_context) + model = init_dummy_model(parallel_context=parallel_context) optimizer = ZeroDistributedOptimizer( named_params_or_groups=model.named_parameters(), optimizer_builder=lambda named_param_groups: NamedOptimizer( @@ -252,13 +252,13 @@ def _test_save_zero_optimizer_and_load_data_parallel_optimizer( model=model, pg=parallel_context.pp_pg, batch=[minibatch], nb_microbatches=1, grad_accumulator=None ) # Manually sync tied parameters - sync_tied_weights_gradients(module=model, dpg=parallel_context, grad_accumulator=None) + sync_tied_weights_gradients(module=model, parallel_context=parallel_context, grad_accumulator=None) # Optimizer steps optimizer.step() optimizer.zero_grad() # Save optimizer - save_optimizer(optimizer=optimizer, dpg=parallel_context, root_folder=store_folder) + save_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=store_folder) dist.barrier(parallel_context.world_pg) # Generate a new optimizer @@ -275,7 +275,7 @@ def _test_save_zero_optimizer_and_load_data_parallel_optimizer( else: assert not match, "Newly initialised optimizer should not match." 
- load_optimizer(optimizer=new_optimizer, dpg=parallel_context, root_folder=store_folder) + load_optimizer(optimizer=new_optimizer, parallel_context=parallel_context, root_folder=store_folder) # TODO @thomasw21: Compare zero optimizer with non zero @@ -301,7 +301,7 @@ def _test_save_data_parallel_optimizer_and_load_zero_optimizer( parallel_context: ParallelContext, test_context: TestContext ): store_folder = test_context.get_auto_remove_tmp_dir() - model = init_dummy_model(dpg=parallel_context) + model = init_dummy_model(parallel_context=parallel_context) optimizer = NamedOptimizer( named_params_or_groups=model.named_parameters(), optimizer_builder=lambda params: torch.optim.AdamW(params), @@ -320,7 +320,7 @@ def _test_save_data_parallel_optimizer_and_load_zero_optimizer( optimizer.zero_grad() # Save optimizer - save_optimizer(optimizer=optimizer, dpg=parallel_context, root_folder=store_folder) + save_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=store_folder) dist.barrier(parallel_context.world_pg) # Generate a new optimizer @@ -341,7 +341,7 @@ def _test_save_data_parallel_optimizer_and_load_zero_optimizer( else: assert not match, "Newly initialised optimizer should not match." - load_optimizer(optimizer=new_optimizer, dpg=parallel_context, root_folder=store_folder) + load_optimizer(optimizer=new_optimizer, parallel_context=parallel_context, root_folder=store_folder) # TODO @thomasw21: Compare zero optimizer with non zero @@ -365,7 +365,7 @@ def test_save_optimizer_with_additional_state_dict_keys(tp: int, dp: int, pp: in def _test_save_optimizer_with_additional_state_dict_keys(parallel_context: ParallelContext, test_context: TestContext): dtype = torch.float16 store_folder = test_context.get_auto_remove_tmp_dir() - model = init_dummy_model(dpg=parallel_context, dtype=dtype) + model = init_dummy_model(parallel_context=parallel_context, dtype=dtype) if isinstance(model, DistributedDataParallel): # Remove the annoying "module." prefix @@ -401,13 +401,15 @@ def _test_save_optimizer_with_additional_state_dict_keys(parallel_context: Paral grad_accumulator=grad_accumulator, ) # Manually sync tied parameters - sync_tied_weights_gradients(module=normalized_model, dpg=parallel_context, grad_accumulator=grad_accumulator) + sync_tied_weights_gradients( + module=normalized_model, parallel_context=parallel_context, grad_accumulator=grad_accumulator + ) # Optimizer steps optimizer.step() optimizer.zero_grad() # Save optimizer - save_optimizer(optimizer=optimizer, dpg=parallel_context, root_folder=store_folder) + save_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=store_folder) dist.barrier(parallel_context.world_pg) # Generate a new optimizer @@ -428,7 +430,7 @@ def _test_save_optimizer_with_additional_state_dict_keys(parallel_context: Paral match, msg = is_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) assert not match, "Newly initialised optimizer should not match." - load_optimizer(optimizer=new_optimizer, dpg=parallel_context, root_folder=store_folder) + load_optimizer(optimizer=new_optimizer, parallel_context=parallel_context, root_folder=store_folder) # Assert the optimizer states are exactly the same after loading. 
match, msg = is_dict_equal(optimizer.state_dict()["state"], new_optimizer.state_dict()["state"]) @@ -486,10 +488,10 @@ def _test_save_and_load_random_states(parallel_context: ParallelContext, test_co assert random_states != random_statess[0] # save - save_random_states(random_states=random_states, dpg=parallel_context, root_folder=store_folder) + save_random_states(random_states=random_states, parallel_context=parallel_context, root_folder=store_folder) # load - new_random_states = load_random_states(dpg=parallel_context, root_folder=store_folder) + new_random_states = load_random_states(parallel_context=parallel_context, root_folder=store_folder) # Each rank has restored it's own random state assert random_states == new_random_states diff --git a/tests/test_tie_weights.py b/tests/test_tie_weights.py index 0b733657..16d794e5 100644 --- a/tests/test_tie_weights.py +++ b/tests/test_tie_weights.py @@ -24,13 +24,13 @@ def _test_tie_weight_in_same_device(parallel_context: ParallelContext): tie_parameters( root_module=model, ties=[("dense0.weight", (0,)), ("dense1.weight", (0,))], - dpg=parallel_context, + parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM, ) tie_parameters( root_module=model, ties=[("dense0.bias", (0,)), ("dense1.bias", (0,))], - dpg=parallel_context, + parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM, ) @@ -66,13 +66,13 @@ def _test_tie_weight_in_different_device(parallel_context: ParallelContext): tie_parameters( root_module=model, ties=[("dense0.weight", (0,)), ("dense1.weight", (1,))], - dpg=parallel_context, + parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM, ) tie_parameters( root_module=model, ties=[("dense0.bias", (0,)), ("dense1.bias", (1,))], - dpg=parallel_context, + parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM, ) @@ -135,14 +135,14 @@ def _test_tie_weight_across_dp_is_impossible(parallel_context: ParallelContext): tie_parameters( root_module=model, ties=[("dense0.weight", (0,)), ("dense1.weight", (1,))], - dpg=parallel_context, + parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM, ) with assert_fail_with(AssertionError): tie_parameters( root_module=model, ties=[("dense0.bias", (0,)), ("dense1.bias", (1,))], - dpg=parallel_context, + parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM, ) @@ -169,13 +169,13 @@ def _test_tie_weight_in_different_device_have_gradients_synchronized(parallel_co tie_parameters( root_module=model, ties=[("dense0.weight", (0,)), ("dense1.weight", (1,))], - dpg=parallel_context, + parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM, ) tie_parameters( root_module=model, ties=[("dense0.bias", (0,)), ("dense1.bias", (1,))], - dpg=parallel_context, + parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM, ) @@ -209,7 +209,7 @@ def _test_tie_weight_in_different_device_have_gradients_synchronized(parallel_co # sync gradients # TODO @thomasw21: This should be done in hooks - sync_tied_weights_gradients(model, dpg=parallel_context, grad_accumulator=None) + sync_tied_weights_gradients(model, parallel_context=parallel_context, grad_accumulator=None) # Check that we have gradient assert weight.grad is not None diff --git a/tests/test_zero.py b/tests/test_zero.py index d7f6675e..2e60faf1 100644 --- a/tests/test_zero.py +++ b/tests/test_zero.py @@ -28,7 +28,7 @@ def test_zero_optimizer(tp: int, dp: int, pp: int): def _test_zero_optimizer(parallel_context: ParallelContext): - model = init_dummy_model(dpg=parallel_context) + model = 
init_dummy_model(parallel_context=parallel_context) optimizer = ZeroDistributedOptimizer( named_params_or_groups=model.named_parameters(), optimizer_builder=lambda named_param_groups: NamedOptimizer( @@ -40,7 +40,7 @@ def _test_zero_optimizer(parallel_context: ParallelContext): index_to_name = [name for name, _ in model.named_parameters()] # reference model - reference_model = init_dummy_model(dpg=parallel_context) + reference_model = init_dummy_model(parallel_context=parallel_context) reference_optimizer = torch.optim.AdamW(reference_model.parameters()) # sync weights between reference_model and model @@ -80,8 +80,8 @@ def _test_zero_optimizer(parallel_context: ParallelContext): ) # Manually sync tied parameters' gradients - sync_tied_weights_gradients(module=model, dpg=parallel_context, grad_accumulator=None) - sync_tied_weights_gradients(module=reference_model, dpg=parallel_context, grad_accumulator=None) + sync_tied_weights_gradients(module=model, parallel_context=parallel_context, grad_accumulator=None) + sync_tied_weights_gradients(module=reference_model, parallel_context=parallel_context, grad_accumulator=None) # We rely on DDP to synchronize gradients across DP. We only need to manually synchronize them if we don't use DDP. if not isinstance(model, DistributedDataParallel): @@ -332,8 +332,8 @@ def _test_zero_optimizer_with_tp( torch.testing.assert_close(loss, ref_loss, msg=lambda msg: f"At iteration {i}, {msg}") # Manually sync tied parameters - sync_tied_weights_gradients(module=model, dpg=parallel_context, grad_accumulator=None) - sync_tied_weights_gradients(module=reference_model, dpg=parallel_context, grad_accumulator=None) + sync_tied_weights_gradients(module=model, parallel_context=parallel_context, grad_accumulator=None) + sync_tied_weights_gradients(module=reference_model, parallel_context=parallel_context, grad_accumulator=None) # We rely on DDP to synchronize gradients across DP. We only need to manually synchronize them if we don't use DDP. 
if not isinstance(model, DistributedDataParallel): From af3acaf2dfda61ee3b0f276ace60cd8d5ac0546d Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 10 Jan 2024 12:25:39 +0000 Subject: [PATCH 4/9] refactor ParallelContext --- src/nanotron/distributed/parallel_context.py | 187 +------------------ src/nanotron/distributed/parallel_mode.py | 9 - tests/test_distributed.py | 35 ---- 3 files changed, 2 insertions(+), 229 deletions(-) delete mode 100644 src/nanotron/distributed/parallel_mode.py diff --git a/src/nanotron/distributed/parallel_context.py b/src/nanotron/distributed/parallel_context.py index f352649b..8f39ec55 100644 --- a/src/nanotron/distributed/parallel_context.py +++ b/src/nanotron/distributed/parallel_context.py @@ -1,46 +1,14 @@ import os -from typing import Dict, Literal, Tuple +from typing import Literal, Tuple import numpy as np import torch import torch.distributed as dist -from nanotron.distributed.parallel_mode import ParallelMode DistributedBackend = Literal["gloo", "mpi", "nccl"] -RanksToDevice = Dict[ParallelMode, int] class ParallelContext: - # @classmethod - # def from_torch( - # cls, - # tensor_parallel_size: int, - # pipeline_parallel_size: int, - # data_parallel_size: int, - # backend: DistributedBackend = "nccl", - # ): - # """Initialize parallel context based on the environment variables defined by torchrun.""" - # rank = int(os.environ["RANK"]) - # local_rank = int(os.environ["LOCAL_RANK"]) - # world_size = int(os.environ["WORLD_SIZE"]) - # local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) - # host = os.environ["MASTER_ADDR"] - # # TODO(xrsrke): make it auto search for ports? - # port = int(os.environ["MASTER_PORT"]) - - # return cls( - # rank=rank, - # local_rank=local_rank, - # world_size=world_size, - # local_world_size=local_world_size, - # host=host, - # port=port, - # backend=backend, - # tensor_parallel_size=tensor_parallel_size, - # pipeline_parallel_size=pipeline_parallel_size, - # data_parallel_size=data_parallel_size, - # ) - def __init__( self, tensor_parallel_size: int, @@ -71,22 +39,12 @@ def __init__( self.pipeline_parallel_size = pipeline_parallel_size self.data_parallel_size = data_parallel_size - # self._global_ranks = {} - # self._local_ranks = {} - # self._world_sizes = {} self._groups = {} - # self._ranks_in_group = {} - # self._ranks_to_device = {} - - # self.local_rank = local_rank - # self.local_world_size = local_world_size self.set_device() if not dist.is_initialized(): rank = int(os.environ["RANK"]) - # local_rank = int(os.environ["LOCAL_RANK"]) - # local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) host = os.environ["MASTER_ADDR"] # TODO(xrsrke): make it auto search for ports? 
port = int(os.environ["MASTER_PORT"]) @@ -116,23 +74,17 @@ def init_global_dist(self, rank: int, world_size: int, backend: DistributedBacke ranks=ranks, backend=dist.get_backend(), ) - # self._register_dist(rank, world_size, process_group, ranks_in_group=ranks, parallel_mode=ParallelMode.GLOBAL) - # self.add_group(ParallelMode.GLOBAL, process_group) self.world_pg = process_group - # self.add_global_rank(ParallelMode.GLOBAL, rank) def init_parallel_groups(self): """Initialize 3D parallelism's all process groups.""" - # rank = self.get_global_rank() - # NOTE: ensure all processes have joined the global group # before creating other groups dist.barrier(group=self.world_pg) - # rank = self.get_global_rank() rank = int(os.environ["RANK"]) - # world_size = self.get_world_size(ParallelMode.GLOBAL) world_size = int(os.environ["WORLD_SIZE"]) + ranks = np.arange(0, world_size).reshape( (self.pipeline_parallel_size, self.data_parallel_size, self.tensor_parallel_size) ) @@ -193,44 +145,11 @@ def init_parallel_groups(self): self.dp_pg = dp_pg self.pp_pg = pp_pg - # parallel_mode_to_pg = { - # ParallelMode.TENSOR: tp_pg, - # ParallelMode.PIPELINE: pp_pg, - # ParallelMode.DATA: dp_pg, - # } - # for parallel_mode in [ParallelMode.TENSOR, ParallelMode.PIPELINE, ParallelMode.DATA]: - # process_group = parallel_mode_to_pg[parallel_mode] - # # self.add_local_rank(parallel_mode, dist.get_rank(process_group)) - # # self.add_world_size(parallel_mode, dist.get_world_size(process_group)) - # self.add_group(parallel_mode, process_group) - # # self.add_ranks_in_group(parallel_mode, dist.get_process_group_ranks(process_group)) - - # TODO(xrsrke): remove world_rank_matrix, world_ranks_to_pg self.world_rank_matrix = ranks self.world_ranks_to_pg = world_ranks_to_pg dist.barrier() - # def _register_dist( - # self, - # local_rank: int, - # local_world_size: int, - # process_group: dist.ProcessGroup, - # ranks_in_group: List[int], - # parallel_mode: ParallelMode, - # ): - # """Register distributed group based on the parallel mode. - - # Args: - # local_rank (int): local rank - # local_world_size (int): local world size - # mode (ParallelMode): parallel mode - # """ - # self.add_local_rank(parallel_mode, local_rank) - # self.add_world_size(parallel_mode, local_world_size) - # self.add_group(parallel_mode, process_group) - # self.add_ranks_in_group(parallel_mode, ranks_in_group) - def set_device(self): local_rank = int(os.getenv("LOCAL_RANK", "0")) @@ -250,110 +169,8 @@ def map_rank_to_device(self): device_id = local_rank torch.cuda.set_device(torch.cuda.device(device_id)) - def is_initialized(self, parallel_mode: ParallelMode) -> bool: - """Check if the parallel mode is initialized. 
- - Args: - mode (ParallelMode): parallel mode - - Returns: - bool: True if the parallel mode is initialized, False otherwise - """ - return True if parallel_mode in self._groups else False - - # def get_global_rank(self) -> int: - # """Get the global rank of the local process.""" - # return self._global_ranks[ParallelMode.GLOBAL] - - # def add_global_rank(self, parallel_mode: ParallelMode, rank: int): - # """Add the global rank of the local process.""" - # self._global_ranks[parallel_mode] = rank - - # def get_local_rank(self, parallel_mode: ParallelMode) -> int: - # """Get the local rank of the local process in a given parallel mode.""" - # return self._local_ranks[parallel_mode] - - # def add_local_rank(self, parallel_mode: ParallelMode, rank: int): - # """Add the local rank of the local process in a given parallel mode.""" - # self._local_ranks[parallel_mode] = rank - - def get_global_rank_from_local_rank(self, local_rank: int, parallel_mode: ParallelMode) -> int: - """Get the global rank from a local rank in a given parallel mode.""" - process_group = self.get_group(parallel_mode) - return dist.get_global_rank(process_group, local_rank) - - # # TODO(xrsrke): add cache - # def get_world_size(self, parallel_mode: ParallelMode) -> int: - # """Get the world size of a given parallel mode.""" - # return self._world_sizes[parallel_mode] - - # def add_world_size(self, parallel_mode: ParallelMode, world_size: int): - # """Add the world size of a given parallel mode.""" - # self._world_sizes[parallel_mode] = world_size - - def add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup) -> int: - """Add a process group of a given parallel mode.""" - self._groups[parallel_mode] = group - - # TODO(xrsrke): add cache - def get_group(self, parallel_mode: ParallelMode) -> dist.ProcessGroup: - """Get a process group of a given parallel mode.""" - return self._groups[parallel_mode] - - # def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks_in_group: List[int]): - # """Add a list of global ranks in a given parallel mode of the local process.""" - # self._ranks_in_group[parallel_mode] = ranks_in_group - - # def get_ranks_in_group(self, parallel_mode: ParallelMode) -> List[int]: - # """A list of global ranks in a given parallel mode of the local process.""" - # return self._ranks_in_group[parallel_mode] - - # def get_next_local_rank(self, rank, parallel_mode: ParallelMode) -> int: - # """Get the next local rank in a given parallel mode.""" - # world_size = self.get_world_size(parallel_mode) - # return (rank + 1) % world_size - - # def get_prev_local_rank(self, rank, parallel_mode: ParallelMode) -> int: - # """Get the previous local rank in a given parallel mode.""" - # world_size = self.get_world_size(parallel_mode) - # return (rank - 1) % world_size - - # def is_first_rank(self, parallel_mode: ParallelMode) -> bool: - # local_rank = self.get_local_rank(parallel_mode) - # return local_rank == 0 - - # def is_last_rank(self, parallel_mode: ParallelMode) -> bool: - # local_rank = self.get_local_rank(parallel_mode) - # world_size = self.get_world_size(parallel_mode) - # return local_rank == world_size - 1 - def get_3d_ranks(self, world_rank: int) -> Tuple[int, int, int]: - # tp_world_size = self.get_world_size(ParallelMode.TENSOR) - # dp_world_size = self.get_world_size(ParallelMode.DATA) - # pp_world_size = self.get_world_size(ParallelMode.PIPELINE) - - # pp_rank = (world_rank // (tp_world_size * dp_world_size)) % pp_world_size - # dp_rank = (world_rank // tp_world_size) % dp_world_size 
- # tp_rank = world_rank % tp_world_size - # return (pp_rank, dp_rank, tp_rank) pp_rank = (world_rank // (self.tp_pg.size() * self.dp_pg.size())) % self.pp_pg.size() dp_rank = (world_rank // self.tp_pg.size()) % self.dp_pg.size() tp_rank = world_rank % self.tp_pg.size() return (pp_rank, dp_rank, tp_rank) - - def destroy(self): - assert self.is_initialized(ParallelMode.GLOBAL), "Global group must be initialized before destroying." - for mode, group in self._groups.items(): - # NOTE: we destroy the global group last - if mode is not ParallelMode.GLOBAL: - if self.is_initialized(mode) and self.get_world_size(mode) > 1: - # NOTE: only ranks in the parallel group need to synchronize - # before destroying the group - group = self.get_group(mode) - dist.barrier(group=group) - dist.destroy_process_group(group) - - dist.barrier() - dist.destroy_process_group() - - self._groups.clear() diff --git a/src/nanotron/distributed/parallel_mode.py b/src/nanotron/distributed/parallel_mode.py deleted file mode 100644 index 45419977..00000000 --- a/src/nanotron/distributed/parallel_mode.py +++ /dev/null @@ -1,9 +0,0 @@ -from enum import Enum - - -class ParallelMode(Enum): - GLOBAL = "global" - - TENSOR = "tensor" - PIPELINE = "pipeline" - DATA = "data" diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 506d8927..67f20682 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -9,46 +9,11 @@ def _test_init_parallel_context(parallel_context: ParallelContext): - # parallel_modes = [ - # ParallelMode.GLOBAL, - # ParallelMode.TENSOR, - # ParallelMode.PIPELINE, - # ParallelMode.DATA, - # ] - - # assert isinstance(parallel_context.get_global_rank(), int) - - # for parallel_mode in parallel_modes: - # # local_rank = parallel_context.get_local_rank(parallel_mode) - - # assert parallel_context.is_initialized(parallel_mode) is True - # # assert isinstance(parallel_context.get_group(parallel_mode), ProcessGroup) - - # # assert type(parallel_context.get_local_rank(parallel_mode)) == int - # # assert type(parallel_context.get_world_size(parallel_mode)) == int - - # # process_group = parallel_context.get_group(parallel_mode) - # # assert isinstance(process_group, ProcessGroup) - # # ranks_in_group = parallel_context.get_ranks_in_group(parallel_mode) - # # TODO(xrsrke): do an expected list of ranks - # # assert ranks_in_group == dist.get_process_group_ranks(process_group) - # # assert len(ranks_in_group) == parallel_context.get_world_size(parallel_mode) - - # # assert parallel_context.is_first_rank(parallel_mode) == (local_rank == 0) - # # assert parallel_context.is_last_rank(parallel_mode) == ( - # # local_rank == parallel_context.get_world_size(parallel_mode) - 1 - # # ) - assert isinstance(parallel_context.world_pg, ProcessGroup) assert isinstance(parallel_context.tp_pg, ProcessGroup) if parallel_context.tensor_parallel_size > 1 else True assert isinstance(parallel_context.pp_pg, ProcessGroup) if parallel_context.pipeline_parallel_size > 1 else True assert isinstance(parallel_context.dp_pg, ProcessGroup) if parallel_context.data_parallel_size > 1 else True - # parallel_context.destroy() - - # for parallel_mode in parallel_modes: - # assert parallel_context.is_initialized(parallel_mode) is False - @pytest.mark.parametrize( "tp,dp,pp", From 738f3648e3dbe9110da0b353b575b96cf6d7fb85 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 10 Jan 2024 12:31:51 +0000 Subject: [PATCH 5/9] remove mapping rank to device in ParallelContext --- src/nanotron/distributed/parallel_context.py | 10 
---------- 1 file changed, 10 deletions(-) diff --git a/src/nanotron/distributed/parallel_context.py b/src/nanotron/distributed/parallel_context.py index 8f39ec55..af5828a2 100644 --- a/src/nanotron/distributed/parallel_context.py +++ b/src/nanotron/distributed/parallel_context.py @@ -159,16 +159,6 @@ def set_device(self): device_id = local_rank torch.cuda.set_device(torch.cuda.device(device_id)) - def map_rank_to_device(self): - """Map global rank to device.""" - local_rank = int(os.getenv("LOCAL_RANK", "0")) - - # NOTE: Set the device id. - # `torch.cuda.device_count` should return the number of device on a single node. - # We assume the nodes to be homogeneous (same number of gpus per node) - device_id = local_rank - torch.cuda.set_device(torch.cuda.device(device_id)) - def get_3d_ranks(self, world_rank: int) -> Tuple[int, int, int]: pp_rank = (world_rank // (self.tp_pg.size() * self.dp_pg.size())) % self.pp_pg.size() dp_rank = (world_rank // self.tp_pg.size()) % self.dp_pg.size() From 574f2b0119613c2677280ed1c456205b6c8568ba Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Thu, 11 Jan 2024 05:47:22 +0000 Subject: [PATCH 6/9] refactor --- README.md | 2 +- src/nanotron/core/process_groups.py | 151 ------------------- src/nanotron/distributed/__init__.py | 1 - src/nanotron/distributed/parallel_context.py | 5 +- src/nanotron/trainer.py | 4 +- tests/helpers/utils.py | 6 - 6 files changed, 5 insertions(+), 164 deletions(-) delete mode 100644 src/nanotron/core/process_groups.py diff --git a/README.md b/README.md index 22037c16..7eb7e0f9 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ Let's go through some key concepts. from nanotron.distributed import ParallelContext # define your topology -parallel_context = ParallelContext.from_torch( +parallel_context = ParallelContext( tensor_parallel_size=2, data_parallel_size=2, pipeline_parallel_size=2 diff --git a/src/nanotron/core/process_groups.py b/src/nanotron/core/process_groups.py deleted file mode 100644 index ca1434c5..00000000 --- a/src/nanotron/core/process_groups.py +++ /dev/null @@ -1,151 +0,0 @@ -import os -from dataclasses import dataclass, field -from typing import Dict, Tuple - -import numpy as np -import torch -from torch import distributed as torch_dist - -import nanotron.core.distributed as dist - - -@dataclass -class DistributedProcessGroups: - # Default process group, all the ranks are in the same process group - world_pg: dist.ProcessGroup - # Convention, dimensions are [pp,dp,tp] (with values equal to 1 when no parallelism) - world_rank_matrix: np.ndarray - - # process dependent process groups - tp_pg: dist.ProcessGroup - dp_pg: dist.ProcessGroup - pp_pg: dist.ProcessGroup - - # Mapping from sorted list of world ranks to process group - world_ranks_to_pg: Dict[Tuple[int, ...], dist.ProcessGroup] = field(default_factory=dict) - - def __repr__(self) -> str: - return ( - f"world_rank_matrix: ({dist.get_rank(self.world_pg)}, {self.world_rank_matrix.tolist()} (dimensions are [pp,dp,tp])) " - f"pp_pg: ({dist.get_rank(self.pp_pg)}, {dist.get_process_group_ranks(self.pp_pg)}) " - f"dp_pg: ({dist.get_rank(self.dp_pg)}, {dist.get_process_group_ranks(self.dp_pg)}) " - f"tp_pg: ({dist.get_rank(self.tp_pg)}, {dist.get_process_group_ranks(self.tp_pg)}) " - ) - - def get_3d_ranks(self, world_rank: int) -> Tuple[int, int, int]: - pp_rank = (world_rank // (self.tp_pg.size() * self.dp_pg.size())) % self.pp_pg.size() - dp_rank = (world_rank // self.tp_pg.size()) % self.dp_pg.size() - tp_rank = world_rank % self.tp_pg.size() - return 
(pp_rank, dp_rank, tp_rank) - - -def get_process_groups( - data_parallel_size: int = 1, - tensor_parallel_size: int = 1, - pipeline_parallel_size: int = 1, -) -> DistributedProcessGroups: - """ - Generate all the process groups necessary for training, and returning current ranks process groups. - - :param data_parallel_size: int - :param tensor_parallel_size: int - :param pipeline_parallel_size: int - :return: DistributedProcessGroups - """ - if not dist.is_available(): - raise ValueError("`torch.distributed is not available as a package, please install it.") - - if not dist.is_initialized(): - initialize_torch_distributed() - - world_pg = torch_dist.distributed_c10d._get_default_group() - world_size = world_pg.size() - world_rank = dist.get_rank(world_pg) - assert ( - world_size == data_parallel_size * tensor_parallel_size * pipeline_parallel_size - ), f"{world_size} != {data_parallel_size * tensor_parallel_size * pipeline_parallel_size}" - - # In the current implementation in DeepSpeed, tp then dp then pp - # https://cs.github.com/microsoft/DeepSpeed/blob/591744eba33f2ece04c15c73c02edaf384dca226/deepspeed/runtime/pipe/topology.py#L243 - - ranks = np.arange(0, world_size).reshape((pipeline_parallel_size, data_parallel_size, tensor_parallel_size)) - world_ranks_to_pg = {} - - tp_pg: dist.ProcessGroup - ranks_with_tp_last = ranks.reshape((pipeline_parallel_size * data_parallel_size, tensor_parallel_size)) - for tp_ranks in ranks_with_tp_last: - sorted_ranks = tuple(sorted(tp_ranks)) - if sorted_ranks not in world_ranks_to_pg: - new_group = dist.new_group(ranks=tp_ranks) - world_ranks_to_pg[sorted_ranks] = new_group - else: - new_group = world_ranks_to_pg[sorted_ranks] - if world_rank in tp_ranks: - tp_pg = new_group - - dp_pg: dist.ProcessGroup - ranks_with_dp_last = ranks.transpose((0, 2, 1)).reshape( - (pipeline_parallel_size * tensor_parallel_size, data_parallel_size) - ) - for dp_ranks in ranks_with_dp_last: - sorted_ranks = tuple(sorted(dp_ranks)) - if sorted_ranks not in world_ranks_to_pg: - new_group = dist.new_group(ranks=dp_ranks) - world_ranks_to_pg[sorted_ranks] = new_group - else: - new_group = world_ranks_to_pg[sorted_ranks] - if world_rank in dp_ranks: - dp_pg = new_group - - pp_pg: dist.ProcessGroup - ranks_with_pp_last = ranks.transpose((2, 1, 0)).reshape( - (tensor_parallel_size * data_parallel_size, pipeline_parallel_size) - ) - for pp_ranks in ranks_with_pp_last: - sorted_ranks = tuple(sorted(pp_ranks)) - if sorted_ranks not in world_ranks_to_pg: - new_group = dist.new_group(ranks=pp_ranks) - world_ranks_to_pg[sorted_ranks] = new_group - else: - new_group = world_ranks_to_pg[sorted_ranks] - if world_rank in pp_ranks: - pp_pg = new_group - - # We build model parallel group (combination of both tensor parallel and pipeline parallel) - for dp_rank in range(data_parallel_size): - pp_and_tp_ranks = ranks[:, dp_rank, :].reshape(-1) - sorted_ranks = tuple(sorted(pp_and_tp_ranks)) - if sorted_ranks not in world_ranks_to_pg: - new_group = dist.new_group(ranks=pp_and_tp_ranks) - world_ranks_to_pg[sorted_ranks] = new_group - - return DistributedProcessGroups( - world_pg=world_pg, - world_rank_matrix=ranks, - dp_pg=dp_pg, - tp_pg=tp_pg, - pp_pg=pp_pg, - world_ranks_to_pg=world_ranks_to_pg, - ) - - -def initialize_torch_distributed(): - rank = int(os.getenv("RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - local_rank = int(os.getenv("LOCAL_RANK", "0")) - - if torch.cuda.is_available(): - # Set the device id. 
- # `torch.cuda.device_count` should return the number of device on a single node. - # We assume the nodes to be homogeneous (same number of gpus per node) - device_id = local_rank - torch.cuda.set_device(torch.cuda.device(device_id)) - backend = "nccl" - else: - # TODO @thomasw21: Maybe figure out a way to do distributed `cpu` training at some point - raise NotImplementedError(f"CUDA was not found: torch.cuda.is_available(): {torch.cuda.is_available()}") - backend = "gloo" - - # Call the init process. - torch_dist.init_process_group(backend=backend, world_size=world_size, rank=rank, timeout=dist.default_pg_timeout) - return True diff --git a/src/nanotron/distributed/__init__.py b/src/nanotron/distributed/__init__.py index 33a98f44..9e4e5f8d 100644 --- a/src/nanotron/distributed/__init__.py +++ b/src/nanotron/distributed/__init__.py @@ -1,2 +1 @@ from nanotron.distributed.parallel_context import ParallelContext -from nanotron.distributed.parallel_mode import ParallelMode diff --git a/src/nanotron/distributed/parallel_context.py b/src/nanotron/distributed/parallel_context.py index af5828a2..bde2c390 100644 --- a/src/nanotron/distributed/parallel_context.py +++ b/src/nanotron/distributed/parallel_context.py @@ -50,8 +50,7 @@ def __init__( port = int(os.environ["MASTER_PORT"]) self.init_global_dist(rank, world_size, backend, host, port) - self.init_parallel_groups() - dist.barrier() + self._init_parallel_groups() def init_global_dist(self, rank: int, world_size: int, backend: DistributedBackend, host: str, port: int): """Initialize the global distributed group. @@ -76,7 +75,7 @@ def init_global_dist(self, rank: int, world_size: int, backend: DistributedBacke ) self.world_pg = process_group - def init_parallel_groups(self): + def _init_parallel_groups(self): """Initialize 3D parallelism's all process groups.""" # NOTE: ensure all processes have joined the global group # before creating other groups diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 599c263a..a188d645 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -170,8 +170,8 @@ def __init__(self, config_or_config_file: Union[Config, str]): group=self.parallel_context.world_pg, rank=None, ) - if free_mem < MIN_GPU_MEM_THRESHOLD: - raise RuntimeError(f"Not enough memory to train the model on node {os.environ.get('SLURMD_NODENAME')}") + # if free_mem < MIN_GPU_MEM_THRESHOLD: + # raise RuntimeError(f"Not enough memory to train the model on node {os.environ.get('SLURMD_NODENAME')}") # Try to allocate all the memory test_tensor_size = int(free_mem * 0.9) test_tensor = torch.zeros((test_tensor_size,), dtype=torch.uint8, device=torch.device("cuda")) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 7b7883aa..504ccd9e 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -72,12 +72,6 @@ def __init__(self, func, args, kwargs, tp: int, dp: int, pp: int): def __call__(self): with mock_os_environ(update_key_values={"WORLD_SIZE": f"{self.tp * self.dp * self.pp}"}): - # parallel_context = get_process_groups( - # data_parallel_size=self.dp, - # pipeline_parallel_size=self.pp, - # tensor_parallel_size=self.tp, - # ) - parallel_context = ParallelContext( data_parallel_size=self.dp, pipeline_parallel_size=self.pp, From bb03782d25b6dc7b707aea4fabbf9713ff422fc4 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Thu, 11 Jan 2024 14:04:22 +0000 Subject: [PATCH 7/9] refactor --- .pre-commit-config.yaml | 1 - src/nanotron/distributed/__init__.py | 2 ++ tests/test_distributed.py | 9 +++++++++ 3 
files changed, 11 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7becc218..a6045cfb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,4 +19,3 @@ repos: args: - --fix - --exit-non-zero-on-fix - exclude: ^src/nanotron/distributed/__init__.py$ diff --git a/src/nanotron/distributed/__init__.py b/src/nanotron/distributed/__init__.py index 9e4e5f8d..bd4e2b2e 100644 --- a/src/nanotron/distributed/__init__.py +++ b/src/nanotron/distributed/__init__.py @@ -1 +1,3 @@ from nanotron.distributed.parallel_context import ParallelContext + +__all__ = ["ParallelContext"] diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 67f20682..aa884e50 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -1,4 +1,6 @@ +import numpy as np import pytest +import torch.distributed as dist from helpers.utils import ( available_gpus, get_all_3d_configurations, @@ -14,6 +16,13 @@ def _test_init_parallel_context(parallel_context: ParallelContext): assert isinstance(parallel_context.pp_pg, ProcessGroup) if parallel_context.pipeline_parallel_size > 1 else True assert isinstance(parallel_context.dp_pg, ProcessGroup) if parallel_context.data_parallel_size > 1 else True + world_rank = dist.get_rank(parallel_context.world_pg) + ranks3d = parallel_context.get_3d_ranks(world_rank) + assert type(ranks3d) and len(ranks3d) + + assert isinstance(parallel_context.world_rank_matrix, np.ndarray) + assert isinstance(parallel_context.world_ranks_to_pg, dict) + @pytest.mark.parametrize( "tp,dp,pp", From fac49b7bbe6c13f63ae83041b0156e4986dc92ca Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Thu, 11 Jan 2024 14:09:05 +0000 Subject: [PATCH 8/9] fix tests --- tests/test_distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_distributed.py b/tests/test_distributed.py index aa884e50..3e7cdc2f 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -18,7 +18,7 @@ def _test_init_parallel_context(parallel_context: ParallelContext): world_rank = dist.get_rank(parallel_context.world_pg) ranks3d = parallel_context.get_3d_ranks(world_rank) - assert type(ranks3d) and len(ranks3d) + assert isinstance(ranks3d, tuple) and len(ranks3d) assert isinstance(parallel_context.world_rank_matrix, np.ndarray) assert isinstance(parallel_context.world_ranks_to_pg, dict) From 490bc83381fa7a0bf6d4401549f1b7dbc30566c1 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 12 Jan 2024 07:50:24 +0000 Subject: [PATCH 9/9] uncomment memory checking --- src/nanotron/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index a188d645..599c263a 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -170,8 +170,8 @@ def __init__(self, config_or_config_file: Union[Config, str]): group=self.parallel_context.world_pg, rank=None, ) - # if free_mem < MIN_GPU_MEM_THRESHOLD: - # raise RuntimeError(f"Not enough memory to train the model on node {os.environ.get('SLURMD_NODENAME')}") + if free_mem < MIN_GPU_MEM_THRESHOLD: + raise RuntimeError(f"Not enough memory to train the model on node {os.environ.get('SLURMD_NODENAME')}") # Try to allocate all the memory test_tensor_size = int(free_mem * 0.9) test_tensor = torch.zeros((test_tensor_size,), dtype=torch.uint8, device=torch.device("cuda"))
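
For reference, the API this series converges on can be exercised with a short script like the one below. This is an illustrative sketch only, not part of any patch in the series: the launch command and script name are arbitrary, and it assumes a `torchrun --nproc_per_node=8` launch on NCCL-capable GPUs so that WORLD_SIZE equals tensor_parallel_size * data_parallel_size * pipeline_parallel_size (2 * 2 * 2 here, matching the README example updated in PATCH 6/9).

import torch.distributed as dist

from nanotron.distributed import ParallelContext

# Topology from the README example: 2 * 2 * 2 == 8 processes expected.
parallel_context = ParallelContext(
    tensor_parallel_size=2,
    pipeline_parallel_size=2,
    data_parallel_size=2,
)

# Every rank belongs to exactly one tp/dp/pp group, and the group sizes
# multiply back to the world size.
assert (
    parallel_context.tp_pg.size()
    * parallel_context.dp_pg.size()
    * parallel_context.pp_pg.size()
    == parallel_context.world_pg.size()
)

# get_3d_ranks() decomposes a world rank into (pp, dp, tp) coordinates;
# indexing world_rank_matrix (laid out as [pp, dp, tp]) with them must
# recover the same world rank.
world_rank = dist.get_rank(parallel_context.world_pg)
pp_rank, dp_rank, tp_rank = parallel_context.get_3d_ranks(world_rank)
assert parallel_context.world_rank_matrix[pp_rank, dp_rank, tp_rank] == world_rank

The final assert restates the invariant behind get_3d_ranks: the world rank matrix is built with np.arange(world_size).reshape((pp, dp, tp)), so decomposing a world rank with the modular arithmetic in get_3d_ranks and indexing back into the matrix yields the same rank on every process.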