diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index ceb33c9ac7a8..bd65a3f8f702 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -140,7 +140,7 @@ jobs: - name: Install Colossal-AI run: | - BUILD_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v . pip install --no-cache-dir -r requirements/requirements-test.txt - name: Store Colossal-AI Cache diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index f8ca07d9731e..278f0f72f8b3 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -55,7 +55,7 @@ jobs: if: steps.check-avai.outputs.avai == 'true' run: | [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ - BUILD_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v . cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ pip install --no-cache-dir -r requirements/requirements-test.txt diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py index fc7e0b74179a..b2ba47f6762d 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py @@ -43,7 +43,7 @@ def zero_grad(self): dtype: torch.dtype @abstractmethod - def pre_backward(self, loss: Tensor) -> Tensor: + def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor: """Called before backward. Args: diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index 9e07bdebf8fa..8fb56aee4fce 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -85,13 +85,18 @@ def __init__( master_params.append(master_p) group["params"] = master_params - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): loss = self.mixed_precision.pre_backward(loss) - loss.backward(*args, **kwargs) + loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) - tensor.backward(grad) + torch.autograd.backward( + tensors=tensor, + grad_tensors=grad, + inputs=inputs, + retain_graph=retain_graph, + ) def zero_grad(self, *args, **kwargs): for p in self.working_to_master_map.keys(): diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py index c757a878d97a..a85d9f808546 100644 --- a/colossalai/booster/mixed_precision/fp16_torch.py +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -46,9 +46,9 @@ def __init__( growth_interval=growth_interval, ) - def backward(self, loss: Tensor, *args, **kwargs) -> None: + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None: scaled_loss = self.scale_loss(loss) - scaled_loss.backward(*args, **kwargs) + scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) def step(self, *args, **kwargs) -> Optional[float]: out = self.scaler.step(self.optim, *args, **kwargs) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index b3615d9294e0..4807ea5c7cd2 
100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -28,7 +28,7 @@ from colossalai.interface.optimizer import DistributedOptim from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed -from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule +from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.quantization.fp8_hook import FP8Hook @@ -295,7 +295,7 @@ def __init__( self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 super().__init__(optim) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): r""" Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -313,8 +313,12 @@ def backward(self, loss: Tensor, *args, **kwargs): """ # Call the superclass backward method to compute gradients. with self.model._hook_context(): - super().backward(loss, *args, **kwargs) + super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -323,7 +327,7 @@ def backward(self, loss: Tensor, *args, **kwargs): # If gradient synchronization is is not required, return. return - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. @@ -340,7 +344,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor): """ # Call the superclass backward method to compute gradients. - super().backward_by_grad(tensor, grad) + super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -520,7 +524,7 @@ def __init__( max_norm=max_norm, ) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): r""" Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -537,8 +541,12 @@ def backward(self, loss: Tensor, *args, **kwargs): None """ # Call the superclass backward method to compute gradients. with self.model._hook_context(): - super().backward(loss, *args, **kwargs) + super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -547,7 +555,7 @@ def backward(self, loss: Tensor, *args, **kwargs): # If gradient synchronization is is not required, return.
return - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. @@ -563,7 +571,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor): None """ # Call the superclass backward method to compute gradients. - super().backward_by_grad(tensor, grad) + super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -780,7 +788,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]: else: return - def backward(self, loss, retain_graph=False): + def backward(self, loss, inputs=None, retain_graph=False): """ Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -796,7 +804,7 @@ def backward(self, loss, retain_graph=False): None """ # Call the superclass backward method to compute gradients. - super().backward(loss, retain_graph) + super().backward(loss, inputs=inputs, retain_graph=retain_graph) if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -805,7 +813,7 @@ def backward(self, loss, retain_graph=False): # If gradient synchronization is is not required, return. return - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False): """ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. @@ -821,7 +829,7 @@ def backward_by_grad(self, tensor, grad): None """ # Call the superclass backward_by_grad method to compute gradients. - super().backward_by_grad(tensor, grad) + super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph) if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -1026,6 +1034,7 @@ def __init__( custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, + scheduler_nodes: List = None, num_layers_per_stage: Optional[List[int]] = None, gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, @@ -1044,6 +1053,9 @@ def __init__( dist.get_world_size() % (tp_size * pp_size) == 0 ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + assert ( + not pp_style == "zbv" or scheduler_nodes is not None + ), f"scheduler_nodes must not be None when using zero bubble pipeline." 
if enable_sequence_parallelism: self.sequence_parallelism_mode = ( sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" ) @@ -1105,29 +1117,39 @@ def __init__( self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) self.stage_manager = None - self.schedule = None + self.scheduler = None self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: - assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" - assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style" + assert ( + pp_style in ["interleaved", "zbv"] or num_model_chunks == 1 + ), "num_model_chunks must be 1 when using 1f1b" + assert ( + pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2 + ), "num_model_chunks must be 2 when using zero bubble pipeline" assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" assert ( self.zero_stage <= 1 ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" + if pp_style == "zbv": + self.logger.warning( + """When using the zbv pipeline style, gradient checkpointing must be enabled with use_reentrant=False, e.g. model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})""" + ) self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=self.pp_axis, - enable_interleave=(pp_style == "interleaved"), + enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"), + use_zbv=(pp_style == "zbv"), num_model_chunks=num_model_chunks, num_layers_per_stage=num_layers_per_stage, ) if pp_style == "interleaved": assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" - self.schedule = InterleavedSchedule( + self.scheduler = InterleavedSchedule( stage_manager=self.stage_manager, num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, @@ -1137,13 +1159,21 @@ def __init__( fp8_communication=fp8_communication, ) elif pp_style == "1f1b": - self.schedule = OneForwardOneBackwardSchedule( + self.scheduler = OneForwardOneBackwardSchedule( stage_manager=self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, fp8_communication=fp8_communication, ) + elif pp_style == "zbv": + self.scheduler = ZeroBubbleVPipeScheduler( + stage_manager=self.stage_manager, + schedule=scheduler_nodes, + num_model_chunks=num_model_chunks, + num_microbatch=num_microbatches, + microbatch_size=microbatch_size, + ) else: raise NotImplementedError() if sequence_parallelism_mode == "ring_attn": @@ -1257,7 +1287,6 @@ def configure( # Replace with distributed implementation if exists optimizer = cast_to_distributed(optimizer) - if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: self.logger.warning( "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet.
Disabling ZeRO.", @@ -1374,7 +1403,7 @@ def execute_pipeline( ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() with ctx, model._hook_context(): - outputs = self.schedule.forward_backward_step( + outputs = self.scheduler.forward_backward_step( model, data_iter, criterion, optimizer, return_loss, return_outputs ) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index a05bbea43d74..e7fa34160c17 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -287,7 +287,7 @@ def __init__( self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size) self.stage_manager = None - self.schedule = None + self.scheduler = None self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: @@ -311,7 +311,7 @@ def __init__( if pp_style == "interleaved": assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" - self.schedule = InterleavedSchedule( + self.scheduler = InterleavedSchedule( stage_manager=self.stage_manager, num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, @@ -320,7 +320,7 @@ def __init__( overlap_p2p=overlap_p2p, ) elif pp_style == "1f1b": - self.schedule = OneForwardOneBackwardSchedule( + self.scheduler = OneForwardOneBackwardSchedule( stage_manager=self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size, diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index a236434a55d6..c8cf3ec21360 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -49,11 +49,11 @@ def zero_grad(self, *args, **kwargs): """ self.optim.zero_grad(*args, **kwargs) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): """ Performs a backward pass on the loss. 
""" - loss.backward(*args, **kwargs) + loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 50cc965bb9c3..5cc32114daff 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -136,7 +136,11 @@ def is_last_stage(self, ignore_chunk: bool = False) -> bool: if not self.is_interleave or ignore_chunk: return self.stage == self.num_stages - 1 else: - return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 + # use zero bubble pipeline + if self.use_zbv: + return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1 + else: + return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 @property def num_stages(self) -> int: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index ec517da4966f..f9897b8b757c 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -261,7 +261,9 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(ignore_chunk=True): + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.norm) + elif stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(module.norm) else: @@ -353,7 +355,9 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + elif stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.lm_head) return held_layers @@ -411,7 +415,9 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(self.model.score) + elif stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.score) return held_layers diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index dbaae66108fe..9111c3b5debd 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -381,7 +381,7 @@ def backward(self, loss: torch.Tensor): loss.backward() self._post_backward() - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False): raise RuntimeError("Gemini is not compatible with pipeline. 
backward_by_grad shoudn't be called in Gemini.") @staticmethod diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index fdf2a497626f..ccd4634b5fe2 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -298,12 +298,14 @@ def backward(self, loss: torch.Tensor): loss = self.mix_precision_mixin.pre_backward(loss) self.module.backward(loss) - def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor): + def backward_by_grad( + self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False + ): # This function is called except the last stage of pipeline parallel # It receives the scaled grad from the previous rank # No need to scale the grad again # Need to unscale when optimizing - grad = self.mix_precision_mixin.pre_backward_by_grad(grad) + grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph) self.module.backward_by_grad(tensor, grad) def _maybe_move_fp32_params(self): diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index ed51c2bacafc..f140147f79b8 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -423,7 +423,7 @@ def _add_to_bucket(self, param, group_id): # torch.optim.Optimizer methods ################################ - def backward(self, loss, retain_graph=False): + def backward(self, loss, inputs=None, retain_graph=False): assert not ( self._partition_grads and not self.require_grad_sync ), "ZeRO2(partition_grads) and no_sync are not compatible" @@ -434,6 +434,7 @@ def backward(self, loss, retain_graph=False): ctx = nullcontext() if self._backward_context is None else self._backward_context() with ctx: - loss.backward(retain_graph=retain_graph) + loss.backward(inputs=inputs, retain_graph=retain_graph) if not self.require_grad_sync: return @@ -444,14 +445,19 @@ def backward(self, loss, retain_graph=False): if self._overlap_communication: get_accelerator().synchronize() - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False): assert not ( self._partition_grads and not self.require_grad_sync ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" if self.mixed_precision_mixin is not None: grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) - torch.autograd.backward(tensor, grad) + torch.autograd.backward( + tensor, + grad, + inputs=inputs, + retain_graph=retain_graph, + ) if not self.require_grad_sync: return diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 0f9ec601387b..3a8057c1fc30 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -310,8 +310,16 @@ def check_output_hidden_state( ): org_hidden_state = org_output.last_hidden_state - if stage_manager and stage_manager.is_last_stage(ignore_chunk=True): - sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] + if stage_manager: + if stage_manager.use_zbv: + if stage_manager.is_first_stage(ignore_chunk=True): + sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] + else: + sharded_hidden_state = sharded_output.last_hidden_state + elif stage_manager.is_last_stage(ignore_chunk=True): + sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] + else:
+ sharded_hidden_state = sharded_output.last_hidden_state else: sharded_hidden_state = sharded_output.last_hidden_state @@ -388,7 +396,6 @@ def get_grad_tensors_for_check( pass if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") - grad_to_check[suffix] = { "org_grad": org_grad.float(), "shard_grad": shard_grad.float(), diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index e8f7916972a5..f3b4db1cefc1 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -7,6 +7,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.schedule.v_schedule import PipelineGraph from colossalai.shardformer import PipelineGradientCheckpointConfig from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter @@ -33,7 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) if enable_gradient_checkpointing: # org_model.gradient_checkpointing_enable() - sharded_model.unwrap().gradient_checkpointing_enable() + sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster @@ -112,12 +113,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, sharded_optimizer.step() # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True): + check_flag = False + if stage_manager is None: + check_flag = True + else: + if stage_manager.use_zbv: + if stage_manager.is_first_stage(ignore_chunk=True): + check_flag = True + elif stage_manager.is_last_stage(ignore_chunk=True): + check_flag = True + if check_flag: if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == "LlamaModel": check_output_hidden_state( org_output, @@ -270,10 +279,39 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, + { + "tp_size": 2, + "pp_size": 2, + "pp_style": "zbv", + "num_model_chunks": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "precision": "fp16", + "zero_stage": 0, + "initial_scale": 1, + "enable_gradient_checkpointing": True, + "parallel_output": False, + }, ], ) def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") + if test_config.get("pp_style", None) == "zbv": + mem_f = 34 * 32 + 5 * 4 * 16 + mem_w = -32 * 32 + mem_b = -mem_w - mem_f + scheduler_nodes = PipelineGraph( + n_stage=test_config["pp_size"], + n_micro=test_config["num_microbatches"], + f_cost=1000, + b_cost=1000, + w_cost=1000, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, + ).get_v_schedule() + test_config["scheduler_nodes"] = scheduler_nodes for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: continue
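
A note on the backward signature changes above: the optimizer wrappers now take explicit inputs and retain_graph arguments and forward them to autograd, which is what a zero-bubble schedule needs in order to split the backward pass into an input-gradient pass and a later weight-gradient pass. Below is a minimal sketch of what the extended OptimizerWrapper.backward call allows; the toy model and tensors are illustrative assumptions, not code from this diff.

    import torch

    from colossalai.interface import OptimizerWrapper

    # Toy model; any differentiable module works for this illustration.
    model = torch.nn.Linear(4, 4)
    optim = OptimizerWrapper(torch.optim.SGD(model.parameters(), lr=0.1))

    x = torch.randn(2, 4, requires_grad=True)
    loss = model(x).sum()

    # Input-gradient pass: restrict autograd to the activation input and keep
    # the graph alive so the weight gradients can still be computed later.
    optim.backward(loss, inputs=[x], retain_graph=True)

    # Weight-gradient pass: backpropagate into the parameters and release the graph.
    optim.backward(loss, inputs=list(model.parameters()), retain_graph=False)
    optim.step()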
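
For reference, here is a sketch of how the new "zbv" pipeline style wired up in this diff might be configured end to end, assembled from the test configuration added to test_shard_llama.py. The tiny Llama config, learning rate, and the cost/memory constants fed to PipelineGraph are placeholder assumptions, and it assumes 4 devices (tp_size * pp_size) launched with torchrun.

    import torch
    from transformers import LlamaConfig, LlamaForCausalLM

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin
    from colossalai.pipeline.schedule.v_schedule import PipelineGraph

    colossalai.launch_from_torch()  # expects a torchrun-style distributed environment

    # Build the zero-bubble V schedule; the costs below are synthetic, as in the test.
    mem_f = 34 * 32 + 5 * 4 * 16
    mem_w = -32 * 32
    mem_b = -mem_w - mem_f
    scheduler_nodes = PipelineGraph(
        n_stage=2,  # pp_size
        n_micro=4,  # num_microbatches
        f_cost=1000,
        b_cost=1000,
        w_cost=1000,
        c_cost=1,
        f_mem=mem_f,
        b_mem=mem_b,
        w_mem=mem_w,
    ).get_v_schedule()

    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=2,
        pp_style="zbv",  # new style introduced by this diff
        num_model_chunks=2,  # zbv requires exactly two model chunks
        num_microbatches=4,
        scheduler_nodes=scheduler_nodes,  # required whenever pp_style == "zbv"
        precision="fp16",
        zero_stage=0,
        initial_scale=1,
    )
    booster = Booster(plugin=plugin)

    model = LlamaForCausalLM(
        LlamaConfig(hidden_size=128, intermediate_size=256, num_hidden_layers=4, num_attention_heads=4, num_key_value_heads=4)
    )
    # zbv needs non-reentrant checkpointing, as the plugin now warns.
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    model, optimizer, *_ = booster.boost(model, optimizer)

    # Training then proceeds through booster.execute_pipeline(...) exactly as with
    # the existing "1f1b" and "interleaved" styles.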