diff --git a/.gitignore b/.gitignore
index 348562a3..d340b158 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,10 +14,11 @@ docs/d3rlpy*.rst
 docs/modules.rst
 docs/references/generated
 coverage.xml
-.coverage
+.coverage*
 .mypy_cache
 .ipynb_checkpoints
 build
 dist
 /.idea/
 *.egg-info
+*.DS_Store
\ No newline at end of file
diff --git a/d3rlpy/algos/qlearning/torch/callbacks.py b/d3rlpy/algos/qlearning/torch/callbacks.py
new file mode 100644
index 00000000..09fe0dff
--- /dev/null
+++ b/d3rlpy/algos/qlearning/torch/callbacks.py
@@ -0,0 +1,76 @@
+from abc import ABCMeta, abstractmethod
+from typing import List, Optional, Sequence
+
+import torch.nn as nn
+
+from ... import QLearningAlgoBase, QLearningAlgoImplBase
+from ....constants import IMPL_NOT_INITIALIZED_ERROR
+
+__all__ = ["ParameterReset"]
+
+
+class QLearningCallback(metaclass=ABCMeta):
+    @abstractmethod
+    def __call__(
+        self, algo: QLearningAlgoBase, epoch: int, total_step: int
+    ) -> None:
+        pass
+
+
+class ParameterReset(QLearningCallback):
+    def __init__(
+        self,
+        replay_ratio: int,
+        encoder_reset: Sequence[bool],
+        output_reset: bool,
+        algo: Optional[QLearningAlgoBase] = None,
+    ) -> None:
+        self._replay_ratio = replay_ratio
+        self._encoder_reset = encoder_reset
+        self._output_reset = output_reset
+        self._check = False
+        if algo is not None:
+            self._check_layer_resets(algo=algo)
+
+    def _get_layers(self, q_func: nn.Module) -> List[nn.Module]:
+        # Flatten a Q-function into its encoder layers followed by the output head.
+        all_modules = {name: module for name, module in q_func.named_modules()}
+        return [*all_modules["_encoder._layers"], all_modules["_fc"]]
+
+    def _check_layer_resets(self, algo: QLearningAlgoBase) -> None:
+        assert algo._impl is not None, IMPL_NOT_INITIALIZED_ERROR
+        assert isinstance(algo._impl, QLearningAlgoImplBase)
+
+        all_valid_layers = []
+        for q_func in algo._impl.q_function:
+            q_func_layers = self._get_layers(q_func)
+            if len(self._encoder_reset) + 1 != len(q_func_layers):
+                raise ValueError(
+                    f"q_function layers: {q_func_layers}; "
+                    f"specified encoder layers: {self._encoder_reset}"
+                )
+            valid_layers = [
+                hasattr(layer, "reset_parameters")
+                for lr, layer in zip(self._encoder_reset, q_func_layers)
+                if lr
+            ]
+            all_valid_layers.append(all(valid_layers))
+        self._check = all(all_valid_layers)
+        if not self._check:
+            raise ValueError(
+                "Some layers do not contain resettable parameters"
+            )
+
+    def __call__(
+        self, algo: QLearningAlgoBase, epoch: int, total_step: int
+    ) -> None:
+        if not self._check:
+            self._check_layer_resets(algo=algo)
+        assert isinstance(algo._impl, QLearningAlgoImplBase)
+        if epoch % self._replay_ratio == 0:
+            reset_lst = [*self._encoder_reset, self._output_reset]
+            for q_func in algo._impl.q_function:
+                q_func_layers = self._get_layers(q_func)
+                for lr, layer in zip(reset_lst, q_func_layers):
+                    if lr:
+                        layer.reset_parameters()
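
Usage sketch (not part of the patch): ParameterReset implements the (algo, epoch, total_step) callback signature, so it can be invoked once per epoch during training. Everything below is illustrative only: it assumes fit() exposes an epoch-level hook (called epoch_callback here), that the default Q-function encoder has two resettable hidden layers, and it uses the cartpole dataset purely as an example.

import d3rlpy
from d3rlpy.algos.qlearning.torch.callbacks import ParameterReset

dataset, env = d3rlpy.datasets.get_cartpole()
algo = d3rlpy.algos.DQNConfig().create()

# Reset the two hidden encoder layers, but keep the output head, every 2nd epoch.
reset_callback = ParameterReset(
    replay_ratio=2,
    encoder_reset=[True, True],
    output_reset=False,
)

algo.fit(
    dataset,
    n_steps=10_000,
    n_steps_per_epoch=1_000,
    # assumed epoch-level hook; otherwise call reset_callback(algo, epoch, step) yourself
    epoch_callback=reset_callback,
)
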
"CQLModules", "DiscreteCQLLoss"] +__all__ = [ + "CQLImpl", + "DiscreteCQLImpl", + "CQLModules", + "DiscreteCQLLoss", + "CQLLoss", +] @dataclasses.dataclass(frozen=True) @@ -26,6 +34,12 @@ class CQLModules(SACModules): alpha_optim: Optional[Optimizer] +@dataclasses.dataclass(frozen=True) +class CQLLoss(DDPGCriticLoss): + td_loss: torch.Tensor + conservative_loss: torch.Tensor + + class CQLImpl(SACImpl): _modules: CQLModules _alpha_threshold: float @@ -65,12 +79,28 @@ def __init__( def compute_critic_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor - ) -> torch.Tensor: - loss = super().compute_critic_loss(batch, q_tpn) + ) -> CQLLoss: + loss = super().compute_critic_loss(batch, q_tpn).loss conservative_loss = self._compute_conservative_loss( batch.observations, batch.actions, batch.next_observations ) - return loss + conservative_loss + return CQLLoss( + loss=loss + conservative_loss, + td_loss=loss, + conservative_loss=conservative_loss, + ) + + def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: + self._modules.critic_optim.zero_grad() + + q_tpn = self.compute_target(batch) + + loss = self.compute_critic_loss(batch, q_tpn) + + loss.loss.backward() + self._modules.critic_optim.step() + + return asdict_as_float(loss) def update_alpha(self, batch: TorchMiniBatch) -> Dict[str, float]: assert self._modules.alpha_optim @@ -274,5 +304,7 @@ def compute_loss( ) loss = td_loss + self._alpha * conservative_loss return DiscreteCQLLoss( - loss=loss, td_loss=td_loss, conservative_loss=conservative_loss + loss=loss, + td_loss=td_loss, + conservative_loss=self._alpha * conservative_loss, ) diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index a1535589..26a09d7e 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -6,13 +6,20 @@ from torch import nn from torch.optim import Optimizer +from ....dataclass_utils import asdict_as_float from ....dataset import Shape from ....models.torch import ContinuousEnsembleQFunctionForwarder, Policy from ....torch_utility import Modules, TorchMiniBatch, hard_sync, soft_sync from ..base import QLearningAlgoImplBase from .utility import ContinuousQFunctionMixin -__all__ = ["DDPGImpl", "DDPGBaseImpl", "DDPGBaseModules", "DDPGModules"] +__all__ = [ + "DDPGImpl", + "DDPGBaseImpl", + "DDPGBaseModules", + "DDPGModules", + "DDPGCriticLoss", +] @dataclasses.dataclass(frozen=True) @@ -24,6 +31,11 @@ class DDPGBaseModules(Modules): critic_optim: Optimizer +@dataclasses.dataclass(frozen=True) +class DDPGCriticLoss: + loss: torch.Tensor + + class DDPGBaseImpl( ContinuousQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta ): @@ -63,15 +75,15 @@ def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: loss = self.compute_critic_loss(batch, q_tpn) - loss.backward() + loss.loss.backward() self._modules.critic_optim.step() - return {"critic_loss": float(loss.cpu().detach().numpy())} + return asdict_as_float(loss) def compute_critic_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor - ) -> torch.Tensor: - return self._q_func_forwarder.compute_error( + ) -> DDPGCriticLoss: + loss = self._q_func_forwarder.compute_error( observations=batch.observations, actions=batch.actions, rewards=batch.rewards, @@ -79,6 +91,7 @@ def compute_critic_loss( terminals=batch.terminals, gamma=self._gamma**batch.intervals, ) + return DDPGCriticLoss(loss=loss) def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: # Q function should be inference mode for 
diff --git a/tests/algos/qlearning/test_callbacks.py b/tests/algos/qlearning/test_callbacks.py
new file mode 100644
index 00000000..7bfec3fa
--- /dev/null
+++ b/tests/algos/qlearning/test_callbacks.py
@@ -0,0 +1,112 @@
+import pytest
+from typing import List, Sequence, Union
+from unittest.mock import MagicMock
+
+from d3rlpy.algos import QLearningAlgoBase, QLearningAlgoImplBase
+from d3rlpy.algos.qlearning.torch.callbacks import ParameterReset
+
+
+class LayerHasResetMock:
+    """Stands in for a layer that supports in-place re-initialization."""
+
+    def reset_parameters(self):
+        return True
+
+
+class LayerNoResetMock:
+    """Stands in for a layer without reset_parameters (e.g. an activation)."""
+
+
+LayerMock = Union[LayerHasResetMock, LayerNoResetMock]
+
+
+class QFuncMock:
+    """Mimics the module layout expected by ParameterReset._get_layers."""
+
+    def __init__(self, encoder_layers: List[LayerMock], fc: LayerMock) -> None:
+        self._encoder_layers = encoder_layers
+        self._fc = fc
+
+    def named_modules(self):
+        return [("_encoder._layers", self._encoder_layers), ("_fc", self._fc)]
+
+
+class ImplMock(MagicMock):
+    def __init__(self, q_funcs: List[QFuncMock]) -> None:
+        super().__init__(spec=QLearningAlgoImplBase)
+        self.q_function = q_funcs
+
+
+class QLearningAlgoBaseMock(MagicMock):
+    def __init__(self, spec, layer_setup: Sequence[bool]) -> None:
+        super().__init__(spec=spec)
+        # all but the last layer_setup entry describe encoder layers; the
+        # last entry is the output head.
+        layers: List[LayerMock] = [
+            LayerHasResetMock() if resettable else LayerNoResetMock()
+            for resettable in layer_setup
+        ]
+        self._impl = ImplMock(q_funcs=[QFuncMock(layers[:-1], layers[-1])])
+
+
+def test_check_layer_resets():
+    algo = QLearningAlgoBaseMock(
+        spec=QLearningAlgoBase, layer_setup=[True, True, False]
+    )
+    replay_ratio = 2
+
+    # matching flags on resettable encoder layers pass the check
+    pr = ParameterReset(
+        replay_ratio=replay_ratio,
+        encoder_reset=[True, True],
+        output_reset=False,
+        algo=algo,
+    )
+    assert pr._check is True
+
+    # flagging an encoder layer that cannot be reset raises
+    algo_invalid = QLearningAlgoBaseMock(
+        spec=QLearningAlgoBase, layer_setup=[True, False, False]
+    )
+    with pytest.raises(ValueError):
+        ParameterReset(
+            replay_ratio=replay_ratio,
+            encoder_reset=[True, True],
+            output_reset=False,
+            algo=algo_invalid,
+        )
+
+    # too many or too few encoder flags raise
+    for encoder_reset in ([True, True, True], [True]):
+        with pytest.raises(ValueError):
+            ParameterReset(
+                replay_ratio=replay_ratio,
+                encoder_reset=encoder_reset,
+                output_reset=False,
+                algo=algo,
+            )
+
+
+def test_call():
+    algo = QLearningAlgoBaseMock(
+        spec=QLearningAlgoBase, layer_setup=[True, True, False]
+    )
+    replay_ratio = 2
+
+    pr = ParameterReset(
+        replay_ratio=replay_ratio,
+        encoder_reset=[True, True],
+        output_reset=False,
+        algo=algo,
+    )
+    pr(algo=algo, epoch=1, total_step=100)
+    pr(algo=algo, epoch=2, total_step=100)
+
+    # without an algo at construction time the check runs on the first call
+    pr = ParameterReset(
+        replay_ratio=replay_ratio,
+        encoder_reset=[True, True],
+        output_reset=False,
+    )
+    pr(algo=algo, epoch=1, total_step=100)
+    pr(algo=algo, epoch=2, total_step=100)
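
Background note (not part of the patch): both the callback and the mocks above rely on the torch convention that parametrized layers expose reset_parameters() for in-place re-initialization, which is why ParameterReset checks hasattr(layer, "reset_parameters") before resetting. A minimal stand-alone illustration:

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
before = layer.weight.detach().clone()
layer.reset_parameters()  # re-draws weights and bias in place
assert not torch.equal(before, layer.weight)

# activations have nothing to re-initialize, so they lack the hook
assert not hasattr(nn.ReLU(), "reset_parameters")
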