Small fixes when resuming training #245

Merged: 6 commits, Nov 21, 2024
10 changes: 8 additions & 2 deletions src/nanotron/optim/named_optimizer.py
@@ -45,7 +45,6 @@ def __init__(
for param in _params:
# https://github.com/pytorch/pytorch/issues/100701
assert param.numel() > 0

super().__init__(optimizer=optimizer_builder(params), id_to_name=id_to_name)

def state_dict(self) -> dict:
@@ -60,9 +59,16 @@ def state_dict(self) -> dict:
return optim_state_dict

def load_state_dict(self, state_dict: dict) -> None:
# TODO @thomasw21: Make a more robust test
assert set(self.id_to_name.values()) == set(
state_dict["names"].values()
), f"Elements don't match:\n - Elements in `self.id_to_name` that aren't in the other one: {set(self.id_to_name.values()) - set(state_dict['names'].values())}\n - Elements in `state_dict[\"names\"]` that aren't in the other one: {set(state_dict['names'].values()) - set(self.id_to_name.values())}"

OPTIMIZER_STATE_KEYS = sorted(state_dict["state"][0].keys() - {"step"})
assert len(state_dict["state"]) == len(state_dict["names"])
for key in OPTIMIZER_STATE_KEYS:
for k, state in state_dict["state"].items():
assert (
key in state
), f"Key {key} not found in state dict: {state} which corresponds to param_name: {state_dict['names'][k]}"

return super().load_state_dict(state_dict)
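
To illustrate the invariant this stricter `load_state_dict` now enforces, here is a small self-contained sketch with a toy Adam-style state dict (the names and tensors are illustrative, not the nanotron API): every parameter entry must carry every optimizer state key, and the id-to-name mapping must agree with the checkpoint.

import torch

# Toy Adam-style optimizer state dict; purely illustrative, not the nanotron API.
toy_state_dict = {
    "names": {0: "model.weight", 1: "model.bias"},
    "state": {
        0: {"step": 10, "exp_avg": torch.zeros(4), "exp_avg_sq": torch.zeros(4)},
        1: {"step": 10, "exp_avg": torch.zeros(2), "exp_avg_sq": torch.zeros(2)},
    },
}
id_to_name = {0: "model.weight", 1: "model.bias"}  # stands in for self.id_to_name

# The checkpointed names must match the optimizer's current parameter names...
assert set(id_to_name.values()) == set(toy_state_dict["names"].values())
# ...and every parameter entry must carry every state key ("step" is a scalar, so it is excluded).
state_keys = sorted(toy_state_dict["state"][0].keys() - {"step"})
assert len(toy_state_dict["state"]) == len(toy_state_dict["names"])
for key in state_keys:
    for k, state in toy_state_dict["state"].items():
        assert key in state, f"{key} missing for param {toy_state_dict['names'][k]}"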
21 changes: 21 additions & 0 deletions src/nanotron/sanity_checks.py
@@ -58,6 +58,7 @@ def before_tbi_sanity_checks(
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Check that the model params are synchronized across dp
@@ -84,6 +85,17 @@ def before_tbi_sanity_checks(
msg=lambda err: f"[Before train] Tied weights {name} are not synchronized. {err}",
)

# SANITY CHECK: Check that model grads are zeroed or None
for name, param in unwrapped_model.named_parameters():
if param.grad is not None:
torch.testing.assert_close(
param.grad,
torch.zeros_like(param.grad),
atol=0,
rtol=0,
msg="Model half precision grads must be zeroed or None in first accumulation step.",
)

# SANITY CHECK: Check that the grad accumulator buffers are ready for DDP
if grad_accumulator is not None:
for _, elt in grad_accumulator.fp32_grad_buffers.items():
@@ -96,6 +108,15 @@ def before_tbi_sanity_checks(
msg="Grad accumulator buffers must be zeroed in first accumulation step.",
)

# TODO: add checks for memory contiguity

# SANITY CHECK: Check that optimizer's lr is synchronized with lr_scheduler
for i, group in enumerate(lr_scheduler.optimizer.param_groups):
assert (
group["lr"] == lr_scheduler.get_last_lr()[i]
), f"Optimizer and LR scheduler are not in sync. Got {group['lr']} and {lr_scheduler.get_last_lr()[i]}"
break

# SANITY CHECK: run model specific sanity checks
unwrapped_model.before_tbi_sanity_checks()

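The new optimizer/scheduler sync check is easiest to see with plain PyTorch, outside nanotron: restoring a scheduler's state dict on its own does not rewrite the learning rate stored in the optimizer's param groups, so after a resume the two can disagree until the schedule is re-applied (which is what the `_initial_step()` call added in `load_lr_scheduler` below takes care of). A minimal sketch of the mismatch:

import torch

# Plain-PyTorch sketch of the condition the sync check guards against; not nanotron code.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 0.5**step)
for _ in range(3):
    optimizer.step()
    scheduler.step()
saved = scheduler.state_dict()  # stands in for the checkpointed scheduler state

# "Resume": fresh optimizer and scheduler, then restore only the scheduler state.
new_param = torch.nn.Parameter(torch.zeros(1))
new_optimizer = torch.optim.AdamW([new_param], lr=1e-3)
new_scheduler = torch.optim.lr_scheduler.LambdaLR(new_optimizer, lr_lambda=lambda step: 0.5**step)
new_scheduler.load_state_dict(saved)

group_lr = new_optimizer.param_groups[0]["lr"]  # still the constructor value (1e-3)
last_lr = new_scheduler.get_last_lr()[0]        # restored scheduled value (1.25e-4)
print(group_lr, last_lr, group_lr == last_lr)   # the sanity check above would fire here
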
59 changes: 44 additions & 15 deletions src/nanotron/serialize/optimizer.py
@@ -1,4 +1,6 @@
import json
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Optional, Tuple

@@ -147,6 +149,9 @@ def load_optimizer(
if int(ckp_tp_size) != int(parallel_context.tp_pg.size()) or int(ckp_pp_size) != int(
parallel_context.pp_pg.size()
):
warnings.warn(
"You are resuming in a different PP size, so optimizer states need to be checked. Feel free to open a PR if you work on this!"
)
assert (
param_shard_metadata is not None
), f"You have to pass how the original parameters are sharded in order to resume in a different tensor parallel size, ckp_tp_size: {ckp_tp_size}, current tp_size: {parallel_context.tp_pg.size()}"
@@ -174,18 +179,24 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
# NOTE: if the checkpoint is from a Zero-0 optimizer, then we don't need to merge the shards
# across data parallel dimension, just directly load the checkpoints
shard_paths = list(
root_folder.glob(f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_tp-*-of-{ckp_tp_size}.pt")
root_folder.glob(
f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_tp-*-of-{ckp_tp_size}.pt"
) # WARN: wildcard here after tp can hold `0-of-1_exp-0`
)

ckp_sharded_optim_states = {}
for shard_path in shard_paths:
pp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=False)
ckp_sharded_optim_states[(pp_rank, tp_rank)] = torch.load(shard_path, map_location=map_location)
ckp_sharded_optim_states[(pp_rank, tp_rank)] = torch.load(
shard_path, map_location=map_location
) # load all optim states in mem

model_state_dict = model.state_dict()
new_optim_state_dict = optimizer.state_dict()
new_optim_state_dict["state"] = defaultdict(dict)
# TODO: this does not handle the edge case of different pipeline parallel optimizer state shards saving different state keys
OPTIMIZER_STATE_NAMES = sorted(ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"])
OPTIMIZER_STATE_DTYPE = ckp_sharded_optim_states[(0, 0)]["state"][0][OPTIMIZER_STATE_NAMES[0]].dtype
# NOTE: because we can only resume training with the same optimizer type
# (0, 0) = (pp_rank, tp_rank)
# NOTE: also we don't merge "step" because it's just a scalar
@@ -224,14 +235,14 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
# from an unsharded optimizer state's shape
new_shard_metadata = param.get_sharded_info()
new_unshared_shape = new_shard_metadata.unsharded_shape
new_optim_state_dict["state"][param_index] = {}
# NOTE: restore each state tensor (e.g. exp_avg) by iterating through
# the optimizer state shards saved using the previous topology
for state_key in OPTIMIZER_STATE_NAMES:
# TODO(xrsrke): free the memory of the shards that isn't
# corresponding to the current rank
buffer = torch.zeros_like(param, device="cuda")
unsharded_buffer = torch.empty(new_unshared_shape, device="cuda")
# TODO: maybe better to allocate memory for all states at once
buffer = torch.zeros_like(param, device="cuda", dtype=OPTIMIZER_STATE_DTYPE)
unsharded_buffer = torch.empty(new_unshared_shape, device="cuda", dtype=OPTIMIZER_STATE_DTYPE)

for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
old_optim_state_index = find_optim_index_from_param_name(
@@ -266,17 +277,34 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
],
new_shard_metadata,
)
else:
# Handle non-sharded params (e.g. layernorm)
for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
old_optim_state_index = find_optim_index_from_param_name(
base_name, ckp_sharded_optim_states, is_zero1=False, pp_rank=pp_rank
)
if old_optim_state_index is None:
continue # Param not in this PP shard

if ckp_optim_type == ZeroDistributedOptimizer.__name__:
# NOTE: flatten the optimizer states
new_optim_state_dict["state"][param_index][state_key] = new_optim_state_dict["state"][
param_index
][state_key].flatten()
# NOTE: a bit awkward, but while we're already reading this (pp,tp) shard for whatever state_key,
# try to get the step value as well.
step = ckp_optim_state["state"][old_optim_state_index].get("step")
if step is not None:
new_optim_state_dict["state"][param_index]["step"] = step
# For non-sharded params, just copy over the state directly
for state_key in OPTIMIZER_STATE_NAMES:
new_optim_state_dict["state"][param_index][state_key] = ckp_optim_state["state"][
old_optim_state_index
][state_key]

if ckp_optim_type == ZeroDistributedOptimizer.__name__:
# NOTE: flatten the optimizer states
new_optim_state_dict["state"][param_index][state_key] = new_optim_state_dict["state"][param_index][
state_key
].flatten()

# NOTE: a bit awkward, but while we're already reading this (pp,tp) shard for whatever state_key,
# try to get the step value as well.
step = ckp_optim_state["state"][old_optim_state_index].get("step")
if step is not None:
new_optim_state_dict["state"][param_index]["step"] = step

# NOTE: we throw away ckp_optim_state['gradient_accumulator'] which has fp32 grads

new_optim_state_dict["names"] = new_optim_state_param_names
state_dict = new_optim_state_dict
@@ -319,3 +347,4 @@ def load_lr_scheduler(

state_dict = torch.load(root_folder / lr_scheduler_filename())
lr_scheduler.load_state_dict(state_dict)
lr_scheduler._initial_step() # NOTE: this is required to set the initial learning rate
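
One detail worth calling out in this file is the buffer allocation: `torch.zeros_like(param)` inherits the parameter's dtype, so if the checkpointed optimizer states were stored in a different dtype than the parameters, copying them into such a buffer would silently cast them. Passing the checkpoint's state dtype, as the `buffer` / `unsharded_buffer` allocations above now do, avoids that. An illustrative sketch (dtypes chosen for demonstration only, run on CPU):

import torch

# Illustrative only: shows why the merge buffers are allocated with the checkpoint's
# optimizer-state dtype rather than the parameter's dtype.
param = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
saved_exp_avg = torch.randn(4, dtype=torch.float32)  # stands in for a checkpointed state tensor

naive_buffer = torch.zeros_like(param)                             # bfloat16: copying would downcast
typed_buffer = torch.zeros_like(param, dtype=saved_exp_avg.dtype)  # float32: dtype preserved

naive_buffer.copy_(saved_exp_avg)  # silent cast to bfloat16, precision lost
typed_buffer.copy_(saved_exp_avg)
print(naive_buffer.dtype, typed_buffer.dtype)  # torch.bfloat16 torch.float32
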
6 changes: 4 additions & 2 deletions src/nanotron/trainer.py
@@ -195,7 +195,7 @@ def __init__(
parallel_context=self.parallel_context,
root_folder=self.init_checkpoint_path,
param_shard_metadata=self.param_shard_metadata,
model=self.model,
model=self.unwrapped_model,
)

# Init learning rate scheduler
@@ -470,7 +470,9 @@ def train(
def training_step(
self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]
) -> Tuple[Iterable[Dict], Optional[torch.Tensor]]:
before_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator)
before_tbi_sanity_checks(
self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.lr_scheduler
)

if self.iteration_step < 5:
log_memory(logger=logger)