diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py
index 807d33b7..13efb68d 100644
--- a/src/modalities/config/config.py
+++ b/src/modalities/config/config.py
@@ -212,6 +212,7 @@ class FSDPWrappedModelConfig(BaseModel):
     mixed_precision_settings: MixedPrecisionSettings
     sharding_strategy: ShardingStrategy
     block_names: List[str]
+    activation_checkpointing_modules: Optional[List[str]] = Field(default_factory=list)
 
     @field_validator("mixed_precision_settings", mode="before")
     def parse_mixed_precision_setting_by_name(cls, name):
@@ -326,7 +327,7 @@ class DummyProgressSubscriberConfig(BaseModel):
 
 class RichProgressSubscriberConfig(BaseModel):
     train_dataloader: PydanticLLMDataLoaderIFType
     eval_dataloaders: Optional[List[PydanticLLMDataLoaderIFType]] = Field(default_factory=list)
-    global_num_seen_steps: int
+    num_seen_steps: int
     global_rank: int
     gradient_acc_steps: Annotated[int, Field(strict=True, gt=0)]

diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py
index 84c26d37..a7aa72a6 100644
--- a/src/modalities/models/model_factory.py
+++ b/src/modalities/models/model_factory.py
@@ -7,6 +7,7 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import ShardingStrategy
 
+from modalities.activation_checkpointing import apply_activation_checkpointing_inplace
 from modalities.checkpointing.checkpoint_loading import CheckpointLoadingIF
 from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
 from modalities.running_env.env_utils import MixedPrecisionSettings
@@ -46,6 +47,7 @@ def get_fsdp_wrapped_model(
         block_names: List[str],
         mixed_precision_settings: MixedPrecisionSettings,
         sharding_strategy: ShardingStrategy,
+        activation_checkpointing_modules: List[str],
     ) -> FSDP:
         """
         Get the FSDP-wrapped model.
@@ -87,6 +89,12 @@ def get_fsdp_wrapped_model(
             f"{get_local_number_of_trainable_parameters(fsdp_model)}"
         )
 
+        if len(activation_checkpointing_modules) > 0:
+            apply_activation_checkpointing_inplace(
+                model=fsdp_model,
+                activation_checkpointing_modules=activation_checkpointing_modules,
+            )
+
         return fsdp_model
 
     @staticmethod
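
The new activation_checkpointing_modules field defaults to an empty list, so existing configs that omit the key keep activation checkpointing disabled. A minimal sketch of the field's validation behavior: FSDPWrappedModelConfigSketch is a hypothetical stand-in that reproduces only the two list fields visible in this diff (the real FSDPWrappedModelConfig carries additional required fields), and "GPT2Block" is an illustrative block class name.

from typing import List, Optional

from pydantic import BaseModel, Field


class FSDPWrappedModelConfigSketch(BaseModel):
    # Hypothetical reduced config: only the two list fields from the diff.
    block_names: List[str]
    activation_checkpointing_modules: Optional[List[str]] = Field(default_factory=list)


# Omitting the key yields an empty list, which leaves activation
# checkpointing disabled in get_fsdp_wrapped_model.
cfg = FSDPWrappedModelConfigSketch(block_names=["GPT2Block"])
assert cfg.activation_checkpointing_modules == []

# Naming block classes opts those blocks into activation checkpointing.
cfg = FSDPWrappedModelConfigSketch(
    block_names=["GPT2Block"],
    activation_checkpointing_modules=["GPT2Block"],
)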
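
For context on what apply_activation_checkpointing_inplace plausibly does with the configured class names, here is a minimal sketch built on PyTorch's experimental checkpoint_wrapper utilities, assuming the helper matches submodules by class name. This is not the modalities implementation, only an illustration of the technique.

import functools

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


def apply_activation_checkpointing_inplace_sketch(
    model: nn.Module, activation_checkpointing_modules: list
) -> None:
    # Resolve the configured class names (e.g. "GPT2Block") against the
    # module classes actually present in the wrapped model.
    module_classes = tuple(
        type(m) for m in model.modules() if type(m).__name__ in activation_checkpointing_modules
    )
    # Non-reentrant checkpointing composes with FSDP; each matched block
    # recomputes its activations during the backward pass instead of
    # keeping them resident in memory.
    wrapper = functools.partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=wrapper,
        check_fn=lambda submodule: isinstance(submodule, module_classes),
    )

Since the diff applies checkpointing to fsdp_model, i.e. after FSDP wrapping, matching on inner module class names as sketched above is one way to keep the config key agnostic of the wrapping; the actual matching strategy used by modalities may differ.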