refactor: moved activation checkpointing to FSDP model factory
le1nux committed Sep 10, 2024
1 parent a9812f3 commit 5d29535
Showing 2 changed files with 10 additions and 1 deletion.
src/modalities/config/config.py (2 additions, 1 deletion)
@@ -212,6 +212,7 @@ class FSDPWrappedModelConfig(BaseModel):
     mixed_precision_settings: MixedPrecisionSettings
     sharding_strategy: ShardingStrategy
     block_names: List[str]
+    activation_checkpointing_modules: Optional[List[str]] = Field(default_factory=list)
 
     @field_validator("mixed_precision_settings", mode="before")
     def parse_mixed_precision_setting_by_name(cls, name):
@@ -326,7 +327,7 @@ class DummyProgressSubscriberConfig(BaseModel):
 class RichProgressSubscriberConfig(BaseModel):
     train_dataloader: PydanticLLMDataLoaderIFType
     eval_dataloaders: Optional[List[PydanticLLMDataLoaderIFType]] = Field(default_factory=list)
-    global_num_seen_steps: int
+    num_seen_steps: int
     global_rank: int
     gradient_acc_steps: Annotated[int, Field(strict=True, gt=0)]
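Note on the config change: because the new activation_checkpointing_modules field is declared with Field(default_factory=list), existing configs that do not set the key keep validating and activation checkpointing stays disabled. A minimal, self-contained sketch of that default behavior (the class name and the "GPT2Block" value below are illustrative, not the full modalities config):

from typing import List, Optional

from pydantic import BaseModel, Field


class WrappedModelConfigSketch(BaseModel):
    # Illustrative subset of FSDPWrappedModelConfig: only the fields relevant here.
    block_names: List[str]
    activation_checkpointing_modules: Optional[List[str]] = Field(default_factory=list)


cfg = WrappedModelConfigSketch(block_names=["GPT2Block"])
print(cfg.activation_checkpointing_modules)  # [] -> no modules selected, checkpointing is skipped

In the model factory (next file), an empty list means the checkpointing helper is never called.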
src/modalities/models/model_factory.py (8 additions, 0 deletions)
@@ -7,6 +7,7 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import ShardingStrategy
 
+from modalities.activation_checkpointing import apply_activation_checkpointing_inplace
 from modalities.checkpointing.checkpoint_loading import CheckpointLoadingIF
 from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
 from modalities.running_env.env_utils import MixedPrecisionSettings
@@ -46,6 +47,7 @@ def get_fsdp_wrapped_model(
         block_names: List[str],
         mixed_precision_settings: MixedPrecisionSettings,
         sharding_strategy: ShardingStrategy,
+        activation_checkpointing_modules: List[str],
     ) -> FSDP:
         """
         Get the FSDP-wrapped model.
@@ -87,6 +89,12 @@ def get_fsdp_wrapped_model(
             f"{get_local_number_of_trainable_parameters(fsdp_model)}"
         )
 
+        if len(activation_checkpointing_modules) > 0:
+            apply_activation_checkpointing_inplace(
+                model=fsdp_model,
+                activation_checkpointing_modules=activation_checkpointing_modules,
+            )
+
         return fsdp_model
 
     @staticmethod
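The diff only shows the call site; apply_activation_checkpointing_inplace itself lives in modalities.activation_checkpointing and is not part of this commit. As a rough sketch of what such a helper typically does with PyTorch's built-in checkpoint utilities (an assumption about the implementation, not the actual modalities code; the "GPT2Block" example is illustrative):

from functools import partial
from typing import List

from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


def apply_activation_checkpointing_inplace(model: nn.Module, activation_checkpointing_modules: List[str]) -> None:
    # Wrap every submodule whose class name is listed (e.g. ["GPT2Block"]) with a
    # non-reentrant activation-checkpointing wrapper; the model is modified in place.
    wrapper_fn = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=wrapper_fn,
        check_fn=lambda module: module.__class__.__name__ in activation_checkpointing_modules,
    )

The guard in the factory (len(activation_checkpointing_modules) > 0) keeps the default path untouched when the config does not request checkpointing.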
