Skip to content

Commit

Permalink
added test to highlight previous bug
Browse files Browse the repository at this point in the history
  • Loading branch information
NouamaneTazi committed Nov 21, 2024
1 parent 013a153 commit 9aedd84
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/nanotron/optim/named_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def __init__(
for param in _params:
# https://github.com/pytorch/pytorch/issues/100701
assert param.numel() > 0

super().__init__(optimizer=optimizer_builder(params), id_to_name=id_to_name)

def state_dict(self) -> dict:
Expand All @@ -60,9 +59,16 @@ def state_dict(self) -> dict:
return optim_state_dict

def load_state_dict(self, state_dict: dict) -> None:
# TODO @thomasw21: Make a more robust test
assert set(self.id_to_name.values()) == set(
state_dict["names"].values()
), f"Elements don't match:\n - Elements in `self.id_to_name` that aren't in the other one: {set(self.id_to_name.values()) - set(state_dict['names'].values())}\n - Elements in `state_dict[\"names\"]` that aren't in the other one: {set(state_dict['names'].values()) - set(self.id_to_name.values())}"

OPTIMIZER_STATE_KEYS = sorted(state_dict["state"][0].keys() - {"step"})
assert len(state_dict["state"]) == len(state_dict["names"])
for key in OPTIMIZER_STATE_KEYS:
for k, state in state_dict["state"].items():
assert (
key in state
), f"Key {key} not found in state dict: {state} which corresponds to param_name: {state_dict['names'][k]}"

return super().load_state_dict(state_dict)

0 comments on commit 9aedd84

Please sign in to comment.