diff --git a/yapt/trainer/trainer.py b/yapt/trainer/trainer.py index 332c16e..e2aa32e 100644 --- a/yapt/trainer/trainer.py +++ b/yapt/trainer/trainer.py @@ -355,7 +355,7 @@ def restore_exp(self): if is_dict(self._model.optimizer): for key in self._model.optimizer.keys(): - self._model.optimizer.load_state_dict( + self._model.optimizer[key].load_state_dict( checkpoint['optimizer_state_dict'][key]) else: self._model.optimizer.load_state_dict(