From e660d2361bc806bd11f70f7b524ea134824a5171 Mon Sep 17 00:00:00 2001
From: hasan-yaman
Date: Wed, 6 Nov 2024 16:21:47 +0300
Subject: [PATCH] fix model loading (#430)

---
 d3rlpy/optimizers/optimizers.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/d3rlpy/optimizers/optimizers.py b/d3rlpy/optimizers/optimizers.py
index 2b965251..0852a467 100644
--- a/d3rlpy/optimizers/optimizers.py
+++ b/d3rlpy/optimizers/optimizers.py
@@ -5,6 +5,7 @@
 from torch.optim import SGD, Adam, AdamW, Optimizer, RMSprop
 from torch.optim.lr_scheduler import LRScheduler
 
+from ..logging import LOG
 from ..serializable_config import DynamicConfig, generate_config_registration
 from .lr_schedulers import LRSchedulerFactory, make_lr_scheduler_field
 
@@ -102,9 +103,15 @@ def state_dict(self) -> Mapping[str, Any]:
         }
 
     def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
-        self._optim.load_state_dict(state_dict["optim"])
+        if "optim" in state_dict:
+            self._optim.load_state_dict(state_dict["optim"])
+        else:
+            LOG.warning("Skip loading optimizer state.")
         if self._lr_scheduler:
-            self._lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
+            if "lr_scheduler" in state_dict:
+                self._lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
+            else:
+                LOG.warning("Skip loading lr scheduler state.")
 
 
 @dataclasses.dataclass()
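
The patch makes `load_state_dict` tolerant of state dicts that predate the `"optim"`/`"lr_scheduler"` keys: missing entries are skipped with a warning instead of raising `KeyError`. Below is a minimal standalone sketch (not part of the patch) of the failure mode being guarded against; the `old_checkpoint` dict is a hypothetical stand-in for a checkpoint saved by an older d3rlpy version, and the prints stand in for the patch's `LOG.warning` calls.

```python
from typing import Any, Mapping

import torch
from torch.optim import Adam

params = [torch.nn.Parameter(torch.zeros(1))]
optim = Adam(params)

# Hypothetical legacy checkpoint: saved before optimizer state was
# serialized, so it has no "optim" key.
old_checkpoint: Mapping[str, Any] = {}

# Behavior before the patch: unconditional lookup raises KeyError.
try:
    optim.load_state_dict(old_checkpoint["optim"])
except KeyError:
    print("old behavior: crash when loading a legacy checkpoint")

# Behavior after the patch: skip and warn instead of crashing.
if "optim" in old_checkpoint:
    optim.load_state_dict(old_checkpoint["optim"])
else:
    print("Skip loading optimizer state.")  # LOG.warning in the patch
```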