Merge pull request #230 from TJ-Solergibert/fix_resume_pp
Fix loading scheduler when having more than one param_group
NouamaneTazi authored Nov 26, 2024
2 parents f6a7db3 + bd81b67 commit cfcdeae
Showing 2 changed files with 15 additions and 10 deletions.
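For context: a PyTorch LR scheduler's state_dict carries one entry per optimizer param_group (e.g. base_lrs), so when ZeRO shards optimizer state across data-parallel ranks, each rank's scheduler state can differ and a single file written by world rank 0 is no longer safe to restore everywhere. A minimal standalone sketch of that per-group state (plain PyTorch, not nanotron code; values illustrative):

    import torch

    # Two param groups with different learning rates, as in the PR title
    # ("more than one param_group"); values here are illustrative only.
    params_a = [torch.nn.Parameter(torch.zeros(1))]
    params_b = [torch.nn.Parameter(torch.zeros(1))]
    opt = torch.optim.AdamW([
        {"params": params_a, "lr": 1e-3},
        {"params": params_b, "lr": 1e-4},
    ])
    sched = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=[lambda step: 1.0, lambda step: 0.5]
    )

    # Scheduler state is keyed per param_group, so it only matches an
    # optimizer with the same group layout.
    print(sched.state_dict()["base_lrs"])  # -> [0.001, 0.0001]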
19 changes: 12 additions & 7 deletions src/nanotron/serialize/optimizer.py
@@ -31,9 +31,11 @@ def optimizer_filename(parallel_context: ParallelContext, is_zero: bool):
         return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
 
 
-def lr_scheduler_filename():
-    """The lr_scheduler is the same for all processes."""
-    return f"{ObjectType.LR_SCHEDULER.value}.pt"
+def lr_scheduler_filename(parallel_context: ParallelContext, is_zero: bool):
+    if is_zero is True:
+        return f"{ObjectType.LR_SCHEDULER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
+    else:
+        return f"{ObjectType.LR_SCHEDULER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
 
 
 def save_optimizer(
@@ -106,12 +108,13 @@ def convert_to_string(input_item):
 
 def save_lr_scheduler(
     lr_scheduler,
+    is_zero,
     parallel_context: ParallelContext,
     root_folder: Path,
 ):
     """Saves lr scheduler states"""
-    if dist.get_rank(parallel_context.world_pg) > 0:
-        # Only WORLD-RANK 0 saves the lr scheduler state
+    if not is_zero and dist.get_rank(parallel_context.dp_pg) > 0:
+        # this is Zero-0, so only DP-0 saves the optimizer states
         return
 
     root_folder = root_folder / "lr_scheduler"
@@ -120,7 +123,7 @@
     # We dump the optimizer state using `torch.save`
     torch.save(
         lr_scheduler.state_dict(),
-        root_folder / lr_scheduler_filename(),
+        root_folder / lr_scheduler_filename(parallel_context, is_zero),
     )
 
 
@@ -356,10 +359,12 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
 
 def load_lr_scheduler(
     lr_scheduler,
+    is_zero,
     parallel_context: ParallelContext,
     root_folder: Path,
 ):
     root_folder = root_folder / "lr_scheduler"
 
-    state_dict = torch.load(root_folder / lr_scheduler_filename())
+    state_dict = torch.load(root_folder / lr_scheduler_filename(parallel_context, is_zero))
     lr_scheduler.load_state_dict(state_dict)
     lr_scheduler._initial_step()  # NOTE: this is required to set the initial learning rate
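For illustration (ranks and sizes hypothetical): under ZeRO-1 with pp=2, dp=4, tp=1 and no expert parallelism, the new scheme writes one scheduler file per rank, such as lr_scheduler_pp-0-of-2_dp-1-of-4_tp-0-of-1_exp-0-of-1.pt; without ZeRO the dp component is dropped, e.g. lr_scheduler_pp-0-of-2_tp-0-of-1_exp-0-of-1.pt, replacing the old single lr_scheduler.pt.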
6 changes: 3 additions & 3 deletions src/nanotron/trainer.py
@@ -209,6 +209,8 @@ def __init__(
         if self.init_checkpoint_path is not None:
             load_lr_scheduler(
                 lr_scheduler=self.lr_scheduler,
+                is_zero=self.config.optimizer.zero_stage,
+                parallel_context=self.parallel_context,
                 root_folder=self.init_checkpoint_path,
             )
 
@@ -906,9 +908,7 @@ def save_checkpoint(self) -> Path:
                 dist.get_rank(self.parallel_context.dp_pg) == 0
             ),  # We only save the weights on DP==0
             should_save_optimizer=True,
-            should_save_lr_scheduler=bool(
-                dist.get_rank(self.parallel_context.world_pg) == 0
-            ),  # We only save the lr_scheduler on world_rank==0
+            should_save_lr_scheduler=True,
             should_save_config=bool(
                 dist.get_rank(self.parallel_context.world_pg) == 0
             ),  # We only save the config on world_rank==0

1 comment on commit cfcdeae

@manuelbrack
Just realized that this commit breaks model saving: save in serialize.main does not pass the new required is_zero parameter.
https://github.com/huggingface/nanotron/blob/main/src/nanotron/serialize/main.py#L106
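A hedged sketch of the kind of follow-up that call site would need, mirroring how trainer.py threads the argument above; the variable names in scope at serialize/main.py are assumptions, not code from this commit:

    # Hypothetical follow-up patch for the save_lr_scheduler call in
    # serialize/main.py; assumes config, parallel_context and root_folder
    # are in scope at the call site, as in the trainer.py hunk above.
    save_lr_scheduler(
        lr_scheduler=lr_scheduler,
        is_zero=config.optimizer.zero_stage,  # same source as trainer.py uses
        parallel_context=parallel_context,
        root_folder=root_folder,
    )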
