From fc4e73370d84af5242996a90b32b3ffce8e6b922 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 14 Nov 2024 13:52:51 -0500 Subject: [PATCH] Add no_sync context manager (#6675) Fix #1902 --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- .../layer_container_base.py | 4 +- deepspeed/runtime/engine.py | 37 +++- deepspeed/runtime/zero/stage_1_and_2.py | 5 - tests/unit/runtime/test_no_sync_ctxt.py | 197 ++++++++++++++++++ 4 files changed, 229 insertions(+), 14 deletions(-) create mode 100644 tests/unit/runtime/test_no_sync_ctxt.py diff --git a/deepspeed/inference/v2/model_implementations/layer_container_base.py b/deepspeed/inference/v2/model_implementations/layer_container_base.py index f26c87556665..feb65b4a5f5d 100644 --- a/deepspeed/inference/v2/model_implementations/layer_container_base.py +++ b/deepspeed/inference/v2/model_implementations/layer_container_base.py @@ -14,7 +14,7 @@ # Currently have dependency loops for the type hints. InferenceModel = Type["InferenceModel"] -LayerContainer = Type["LayerContainer"] +LayerContainer = Type["LayerContainer"] # noqa: F811 MAPPING_KEY = "PARAM_MAPPING" PLIST_HELPERS = "_ds_plist_strip_vals" @@ -161,7 +161,7 @@ def __call__(cls, *args, **kwargs): return instance -class LayerContainer(metaclass=LayerMetaclass): +class LayerContainer(metaclass=LayerMetaclass): # noqa: F811 """ Abstract base class for containing model parameters. diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index e1e745d2b112..8c5da36e5a78 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -17,6 +17,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from contextlib import contextmanager from typing import Callable, Dict, Union, Iterable, Container @@ -216,6 +217,7 @@ def __init__(self, self.loaded_checkpoint_mp_world_size = None self.loaded_checkpoint_dp_world_size = None self.enable_backward_allreduce = True + self.inside_no_sync_ctxt = False self.progressive_layer_drop = None self.eigenvalue = None self.block_eigenvalue = None @@ -1981,12 +1983,31 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): grads = None self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size) + @contextmanager + def no_sync(self): + r""" + Context manager to disable gradient reduction during backward pass. + This context manager has the following effects on other DeepSpeed features. + 1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning. + 2. It is illegal to call engine.step() within the context manager. + 3. Tracking of gradient accumulation steps is disabled. + """ + assert not self.zero_optimization_partition_gradients(), \ + f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}" + + assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported" + + self.inside_no_sync_ctxt = True + try: + yield + finally: + self.inside_no_sync_ctxt = False + @instrument_w_nvtx - def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_graph=False, scale_wrt_gas=True): + def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=True): r"""Execute backward pass on the loss Arguments: loss: Torch tensor on which to execute backward propagation - allreduce_gradients: is deprecated, ignored, and will soon be removed' retain_graph: bool, default: false forward on user defined choice of retain_graph """ @@ -1996,11 +2017,10 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_gr if self.scale_wrt_gas is not None: scale_wrt_gas = self.scale_wrt_gas - if not allreduce_gradients: - logger.warning(f"Argument `allreduce_gradients` is deprecated, ignored, and will soon be removed") + do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt - # scale loss w.r.t. gradient accumulation if needed - if self.gradient_accumulation_steps() > 1 and scale_wrt_gas: + # scale loss w.r.t. gradient accumulation if reduction is not disabled + if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas: loss = self._scale_loss_by_gas(loss.float()) # Log training loss @@ -2049,7 +2069,7 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_gr self._start_timers(self.engine_timers.backward_reduce_timers) - if allreduce_gradients and self.enable_backward_allreduce: + if do_gradient_reduction: # Traditional code path that allreduces the module parameter grads self.allreduce_gradients() @@ -2185,6 +2205,9 @@ def step(self, lr_kwargs=None): r"""Execute the weight update step after forward and backward propagation on effective_train_batch. """ + assert not self.inside_no_sync_ctxt, \ + "It is illegal to call Engine.step() inside no_sync context manager" + see_memory_usage("Engine before step", force=self.memory_breakdown()) # Check early because self.global_steps is incremented at some point here. diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 669826206e4b..7ac89a233808 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -2297,11 +2297,6 @@ def load_state_dict(self, def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights): self.load_hp_checkpoint_state_from_checkpoint_dir("bit16_groups", checkpoint_folder) - @property - def param_groups(self): - """Forward the wrapped optimizer's parameters.""" - return self.optimizer.param_groups - def _load_global_state(self, sd): self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler) self.dynamic_loss_scale = sd.get('dynamic_loss_scale', self.dynamic_loss_scale) diff --git a/tests/unit/runtime/test_no_sync_ctxt.py b/tests/unit/runtime/test_no_sync_ctxt.py new file mode 100644 index 000000000000..8c6497013809 --- /dev/null +++ b/tests/unit/runtime/test_no_sync_ctxt.py @@ -0,0 +1,197 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest + +from contextlib import nullcontext +import torch + +from unit.simple_model import SimpleModel, random_dataloader +from unit.common import DistributedTest + +import deepspeed +import deepspeed.comm as dist +from deepspeed.utils import safe_get_full_grad + + +class TestNoSyncCtxt(DistributedTest): + world_size = 2 + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) + @pytest.mark.parametrize("zero_stage", [0, 1, 2, 3]) + def test_zero_stage(self, zero_stage, dtype): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "zero_optimization": { + "stage": zero_stage, + }, + } + + invalid_cfg = zero_stage > 1 + if dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + elif dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + + hidden_dim = 64 + total_samples = 32 + model = SimpleModel(hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=total_samples, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + dist.barrier() + + with pytest.raises(AssertionError) if invalid_cfg else nullcontext() as assertinfo: + with model.no_sync(): + for _, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + if invalid_cfg: + assert ("no_sync context manager is incompatible" in str(assertinfo)) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) + @pytest.mark.parametrize("zero_stage", [0, 1]) + def test_engine_step(self, zero_stage, dtype): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "zero_optimization": { + "stage": zero_stage, + }, + } + + if dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + elif dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + + hidden_dim = 64 + total_samples = 32 + model = SimpleModel(hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=total_samples, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + dist.barrier() + + with model.no_sync(): + for _, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + with pytest.raises(AssertionError) as assertinfo: + model.step() + assert ("It is illegal to call Engine.step() inside no_sync context manager" in str(assertinfo)) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) + @pytest.mark.parametrize("zero_stage", [0, 1]) + def test_multiple_ctxts(self, zero_stage, dtype): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "zero_optimization": { + "stage": zero_stage, + }, + } + + if dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + elif dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + + hidden_dim = 64 + total_samples = 32 + model = SimpleModel(hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=total_samples, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + dist.barrier() + + param_list = list(model.parameters()) + first_losses = [] + first_grad_norms = [] + with model.no_sync(): + for _, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + first_losses.append(loss.item()) + model.backward(loss) + grad_norm = sum([safe_get_full_grad(p).norm() for p in param_list]) + first_grad_norms.append(grad_norm.item()) + + second_losses = [] + second_grad_norms = [] + + model.zero_grad() + with model.no_sync(): + for _, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + second_losses.append(loss.item()) + model.backward(loss) + grad_norm = sum([safe_get_full_grad(p).norm() for p in param_list]) + second_grad_norms.append(grad_norm.item()) + + assert len(first_losses) == len(second_losses) + for x, y in zip(first_losses, second_losses): + assert x == y + + assert len(first_grad_norms) == len(second_grad_norms) + for x, y in zip(first_grad_norms, second_grad_norms): + assert x == y + + def test_reentry(self): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "zero_optimization": { + "stage": 1, + }, + } + + hidden_dim = 64 + model = SimpleModel(hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + dist.barrier() + + with model.no_sync(): + with pytest.raises(AssertionError) as assertinfo: + with model.no_sync(): + pass + assert ("no_sync context manager reentry is unsupported" in str(assertinfo))