Commit

Fix non-sharded optimizer states when loading a checkpoint
NouamaneTazi committed Nov 21, 2024
1 parent e967e78 commit 013a153
Showing 1 changed file with 30 additions and 14 deletions.
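
The gist of the change: the new optimizer state dict's `"state"` entry becomes a `defaultdict(dict)`, so per-parameter sub-dicts are created lazily (the explicit `new_optim_state_dict["state"][param_index] = {}` line can go away), and non-sharded parameters such as layernorm now get their checkpointed state tensors copied over directly instead of being skipped. A minimal sketch of those two ideas — plain Python/PyTorch, not nanotron's actual API; `ckp_state`, the indices, and the tensor shapes are invented for illustration:

```python
from collections import defaultdict

import torch

# Sketch only: (1) defaultdict(dict) creates per-param sub-dicts on first access,
# (2) non-sharded params get their old state tensors copied over as-is.
new_optim_state = {"state": defaultdict(dict)}

# Pretend this came from a checkpoint shard: param index -> {state_key: tensor}
ckp_state = {0: {"exp_avg": torch.zeros(4), "exp_avg_sq": torch.zeros(4), "step": torch.tensor(100)}}

param_index, old_index = 0, 0
for state_key in ("exp_avg", "exp_avg_sq"):
    # No KeyError and no explicit init: the defaultdict builds new_optim_state["state"][0] lazily.
    new_optim_state["state"][param_index][state_key] = ckp_state[old_index][state_key]

# While reading this shard anyway, also pick up the step counter if it is there.
step = ckp_state[old_index].get("step")
if step is not None:
    new_optim_state["state"][param_index]["step"] = step

print(new_optim_state["state"][0].keys())  # dict_keys(['exp_avg', 'exp_avg_sq', 'step'])
```
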
44 changes: 30 additions & 14 deletions src/nanotron/serialize/optimizer.py
@@ -1,4 +1,5 @@
 import json
+from collections import defaultdict
 from pathlib import Path
 from typing import Optional, Tuple

@@ -188,6 +189,7 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
 
         model_state_dict = model.state_dict()
         new_optim_state_dict = optimizer.state_dict()
+        new_optim_state_dict["state"] = defaultdict(dict)
         # TODO: this does not handle the edge case of different pipeline parallel optimizer state shards saving different state keys
         OPTIMIZER_STATE_NAMES = sorted(ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"])
         OPTIMIZER_STATE_DTYPE = ckp_sharded_optim_states[(0, 0)]["state"][0][OPTIMIZER_STATE_NAMES[0]].dtype
@@ -229,7 +231,6 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
                 # from an unsharded optimizer state's shape
                 new_shard_metadata = param.get_sharded_info()
                 new_unshared_shape = new_shard_metadata.unsharded_shape
-                new_optim_state_dict["state"][param_index] = {}
                 # NOTE: restore each state tensor (e.g. exg_avg) by iterating through
                 # the optimizer state shards saved using the previous topology
                 for state_key in OPTIMIZER_STATE_NAMES:
@@ -272,19 +273,34 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
                         ],
                         new_shard_metadata,
                     )
-
-                    if ckp_optim_type == ZeroDistributedOptimizer.__name__:
-                        # NOTE: flatten the optimizer states
-                        new_optim_state_dict["state"][param_index][state_key] = new_optim_state_dict["state"][
-                            param_index
-                        ][state_key].flatten()
-                # NOTE: a bit awkward, but while we're already reading this (pp,tp) shard for whatever state_key,
-                # try to get the step value as well.
-                step = ckp_optim_state["state"][old_optim_state_index].get("step")
-                if step is not None:
-                    new_optim_state_dict["state"][param_index]["step"] = step
-
-                # NOTE: we throw away ckp_optim_state['gradient_accumulator'] which has fp32 grads
+            else:
+                # Handle non-sharded params (e.g. layernorm)
+                for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
+                    old_optim_state_index = find_optim_index_from_param_name(
+                        base_name, ckp_sharded_optim_states, is_zero1=False, pp_rank=pp_rank
+                    )
+                    if old_optim_state_index is None:
+                        continue  # Param not in this PP shard
+
+                    # For non-sharded params, just copy over the state directly
+                    for state_key in OPTIMIZER_STATE_NAMES:
+                        new_optim_state_dict["state"][param_index][state_key] = ckp_optim_state["state"][
+                            old_optim_state_index
+                        ][state_key]
+
+            if ckp_optim_type == ZeroDistributedOptimizer.__name__:
+                # NOTE: flatten the optimizer states
+                new_optim_state_dict["state"][param_index][state_key] = new_optim_state_dict["state"][param_index][
+                    state_key
+                ].flatten()
+
+            # NOTE: a bit awkward, but while we're already reading this (pp,tp) shard for whatever state_key,
+            # try to get the step value as well.
+            step = ckp_optim_state["state"][old_optim_state_index].get("step")
+            if step is not None:
+                new_optim_state_dict["state"][param_index]["step"] = step
+
+            # NOTE: we throw away ckp_optim_state['gradient_accumulator'] which has fp32 grads
 
         new_optim_state_dict["names"] = new_optim_state_param_names
         state_dict = new_optim_state_dict
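
On the `ZeroDistributedOptimizer` path, the reconstructed per-parameter tensors are flattened before being handed back, presumably because the ZeRO-style optimizer keeps its state as flat 1-D buffers per shard. A toy illustration of that reshaping step in plain PyTorch (no nanotron APIs; the shapes are made up):

```python
import torch

# Toy version of the "flatten the optimizer states" step: a reconstructed 2-D
# state tensor is turned back into the 1-D layout a ZeRO-style optimizer expects.
reconstructed = {"exp_avg": torch.randn(8, 16), "exp_avg_sq": torch.randn(8, 16)}

flat_state = {key: tensor.flatten() for key, tensor in reconstructed.items()}

assert flat_state["exp_avg"].shape == (8 * 16,)
assert torch.equal(flat_state["exp_avg"], reconstructed["exp_avg"].reshape(-1))
```
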
