def load_state_dict(self, state_dict: dict) -> None:
    """Validate ``state_dict`` against this optimizer, then load it via the parent class.

    Checks performed before delegating to ``super().load_state_dict``:

    1. The parameter names in ``state_dict["names"]`` are exactly the names
       in ``self.id_to_name``.
    2. ``state_dict["state"]`` has one entry per entry in ``state_dict["names"]``.
    3. Every per-parameter state holds the same keys as a reference state
       (entry ``0`` when present, otherwise the first entry); ``"step"`` is
       excluded from that comparison.

    Args:
        state_dict: Mapping with at least a ``"names"`` entry (key -> param
            name) and a ``"state"`` entry (key -> per-parameter state dict).

    Raises:
        AssertionError: If any of the checks above fails.
    """
    assert set(self.id_to_name.values()) == set(
        state_dict["names"].values()
    ), f"Elements don't match:\n - Elements in `self.id_to_name` that aren't in the other one: {set(self.id_to_name.values()) - set(state_dict['names'].values())}\n - Elements in `state_dict[\"names\"]` that aren't in the other one: {set(state_dict['names'].values()) - set(self.id_to_name.values())}"

    param_states = state_dict["state"]

    # Compare sizes BEFORE indexing into the states: an empty or truncated
    # state mapping must surface as a validation error, not a bare KeyError.
    assert len(param_states) == len(
        state_dict["names"]
    ), f"Number of optimizer states ({len(param_states)}) doesn't match the number of named params ({len(state_dict['names'])})"

    if param_states:
        # Entry 0 is the historical reference; fall back to the first entry
        # so a mapping that is not zero-indexed is still checked.
        reference_state = param_states[0] if 0 in param_states else next(iter(param_states.values()))
        optimizer_state_keys = sorted(reference_state.keys() - {"step"})
        for key in optimizer_state_keys:
            for k, state in param_states.items():
                assert (
                    key in state
                ), f"Key {key} not found in state dict: {state} which corresponds to param_name: {state_dict['names'][k]}"

    return super().load_state_dict(state_dict)