From 4a6edc9530ffa8c33c0ec29db4057698b1449b09 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Fri, 4 Aug 2023 06:51:15 +0100 Subject: [PATCH 01/10] tracking of cql regularisation for continuous cql --- d3rlpy/algos/qlearning/cql.py | 3 ++- d3rlpy/algos/qlearning/torch/cql_impl.py | 18 +++++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py index 1825e1f3..4146b1de 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -206,8 +206,9 @@ def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: alpha_loss, alpha = self._impl.update_alpha(batch) metrics.update({"alpha_loss": alpha_loss, "alpha": alpha}) - critic_loss = self._impl.update_critic(batch) + critic_loss, cql_loss = self._impl.update_critic(batch) metrics.update({"critic_loss": critic_loss}) + metrics.update({"cql_loss": cql_loss}) actor_loss = self._impl.update_actor(batch) metrics.update({"actor_loss": actor_loss}) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index cf81b715..ed249a62 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -75,7 +75,23 @@ def compute_critic_loss( conservative_loss = self._compute_conservative_loss( batch.observations, batch.actions, batch.next_observations ) - return loss + conservative_loss + return loss + conservative_loss, conservative_loss + + @train_api + def update_critic(self, batch: TorchMiniBatch) -> float: + self._critic_optim.zero_grad() + + q_tpn = self.compute_target(batch) + + loss, cql_loss = self.compute_critic_loss(batch, q_tpn) + + loss.backward() + self._critic_optim.step() + + res = np.array( + [loss.cpu().detach().numpy(), cql_loss.cpu().detach().numpy()] + ) + return res @train_api def update_alpha(self, batch: TorchMiniBatch) -> Tuple[float, float]: From 5b7185cf3b8565f5a55f5aa96d83efc48cf16bbe Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Fri, 4 Aug 2023 11:32:29 +0100 Subject: [PATCH 02/10] updated for linting and formatting --- d3rlpy/algos/qlearning/torch/cql_impl.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index ed249a62..491c2caf 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -76,9 +76,9 @@ def compute_critic_loss( batch.observations, batch.actions, batch.next_observations ) return loss + conservative_loss, conservative_loss - + @train_api - def update_critic(self, batch: TorchMiniBatch) -> float: + def update_critic(self, batch: TorchMiniBatch) -> np.array: self._critic_optim.zero_grad() q_tpn = self.compute_target(batch) @@ -88,9 +88,9 @@ def update_critic(self, batch: TorchMiniBatch) -> float: loss.backward() self._critic_optim.step() - res = np.array( - [loss.cpu().detach().numpy(), cql_loss.cpu().detach().numpy()] - ) + critic_loss = float(loss.cpu().detach().numpy()) + cql_loss = float(cql_loss.cpu().detach().numpy()) + res = np.array([critic_loss, cql_loss]) return res @train_api From 7fd0a37805895d23eecd472a58fe458c09add657 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Thu, 10 Aug 2023 18:02:33 +0100 Subject: [PATCH 03/10] overwriting dr3 pull and aligning cql logging --- d3rlpy/algos/qlearning/cql.py | 4 ++-- d3rlpy/algos/qlearning/torch/cql_impl.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/d3rlpy/algos/qlearning/cql.py 
b/d3rlpy/algos/qlearning/cql.py index 4146b1de..696eaf43 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -206,9 +206,9 @@ def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]: alpha_loss, alpha = self._impl.update_alpha(batch) metrics.update({"alpha_loss": alpha_loss, "alpha": alpha}) - critic_loss, cql_loss = self._impl.update_critic(batch) + critic_loss, conservative_loss = self._impl.update_critic(batch) metrics.update({"critic_loss": critic_loss}) - metrics.update({"cql_loss": cql_loss}) + metrics.update({"conservative_loss": conservative_loss}) actor_loss = self._impl.update_actor(batch) metrics.update({"actor_loss": actor_loss}) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 491c2caf..4096b3e7 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -237,7 +237,8 @@ def compute_loss( conservative_loss = self._compute_conservative_loss( batch.observations, batch.actions.long() ) - return loss + self._alpha * conservative_loss, conservative_loss + cql_loss = self._alpha * conservative_loss + return loss + cql_loss, cql_loss def _compute_conservative_loss( self, obs_t: torch.Tensor, act_t: torch.Tensor From a52651eb6a1c6caaba086fe2be2523a4ddeb3eb0 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Thu, 10 Aug 2023 18:03:58 +0100 Subject: [PATCH 04/10] updated formatting --- d3rlpy/algos/qlearning/torch/cql_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 4096b3e7..19686191 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -238,7 +238,7 @@ def compute_loss( batch.observations, batch.actions.long() ) cql_loss = self._alpha * conservative_loss - return loss + cql_loss, cql_loss + return loss + cql_loss, cql_loss def _compute_conservative_loss( self, obs_t: torch.Tensor, act_t: torch.Tensor From a46d73f60d907a6c11594e43bb72d455971d3f51 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Thu, 10 Aug 2023 18:15:42 +0100 Subject: [PATCH 05/10] update gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 348562a3..3b61fe79 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,7 @@ docs/d3rlpy*.rst docs/modules.rst docs/references/generated coverage.xml -.coverage +.coverage* .mypy_cache .ipynb_checkpoints build From e87cb94806b69f790dd98f79fc1d6495614454c1 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Tue, 29 Aug 2023 11:22:48 +0100 Subject: [PATCH 06/10] corrected formatting --- d3rlpy/algos/qlearning/torch/cql_impl.py | 23 ++++++++++++++--------- d3rlpy/algos/qlearning/torch/ddpg_impl.py | 12 ++++++++---- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 31ba9606..15525bed 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -15,14 +15,17 @@ build_squashed_gaussian_distribution, ) from ....torch_utility import TorchMiniBatch +from .ddpg_impl import DDPGCriticLoss from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules from .sac_impl import SACImpl, SACModules -from .ddpg_impl import DDPGCriticLoss __all__ = [ - "CQLImpl", "DiscreteCQLImpl", "CQLModules", "DiscreteCQLLoss", - "CQLLoss" - ] + "CQLImpl", + "DiscreteCQLImpl", + "CQLModules", + "DiscreteCQLLoss", + 
"CQLLoss", +] @dataclasses.dataclass(frozen=True) @@ -82,9 +85,10 @@ def compute_critic_loss( batch.observations, batch.actions, batch.next_observations ) return CQLLoss( - loss=loss+conservative_loss, td_loss=loss, - conservative_loss=conservative_loss - ) + loss=loss + conservative_loss, + td_loss=loss, + conservative_loss=conservative_loss, + ) def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.critic_optim.zero_grad() @@ -300,6 +304,7 @@ def compute_loss( ) loss = td_loss + self._alpha * conservative_loss return DiscreteCQLLoss( - loss=loss, td_loss=td_loss, - conservative_loss=self._alpha * conservative_loss + loss=loss, + td_loss=td_loss, + conservative_loss=self._alpha * conservative_loss, ) diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index ce442908..26a09d7e 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -14,9 +14,12 @@ from .utility import ContinuousQFunctionMixin __all__ = [ - "DDPGImpl", "DDPGBaseImpl", "DDPGBaseModules", "DDPGModules", - "DDPGCriticLoss" - ] + "DDPGImpl", + "DDPGBaseImpl", + "DDPGBaseModules", + "DDPGModules", + "DDPGCriticLoss", +] @dataclasses.dataclass(frozen=True) @@ -27,6 +30,7 @@ class DDPGBaseModules(Modules): actor_optim: Optimizer critic_optim: Optimizer + @dataclasses.dataclass(frozen=True) class DDPGCriticLoss: loss: torch.Tensor @@ -75,7 +79,7 @@ def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.critic_optim.step() return asdict_as_float(loss) - + def compute_critic_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor ) -> DDPGCriticLoss: From 3ebd8cb9f951659608cca0feb539ddde03e2dc07 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Tue, 29 Aug 2023 11:55:35 +0100 Subject: [PATCH 07/10] first draft of parameter reset callback --- d3rlpy/models/torch/callbacks.py | 47 ++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 d3rlpy/models/torch/callbacks.py diff --git a/d3rlpy/models/torch/callbacks.py b/d3rlpy/models/torch/callbacks.py new file mode 100644 index 00000000..1834e981 --- /dev/null +++ b/d3rlpy/models/torch/callbacks.py @@ -0,0 +1,47 @@ +from abc import ABCMeta, abstractmethod +from typing import Sequence + +from ...algos import QLearningAlgoBase, QLearningAlgoImplBase +from ...constants import IMPL_NOT_INITIALIZED_ERROR + + +class QLearningCallback(metaclass=ABCMeta): + @abstractmethod + def __call__(self, algo: QLearningAlgoBase, epoch: int, total_step: int): + pass + + +class ParameterReset(QLearningCallback): + def __init__(self, replay_ratio: int, layer_reset:Sequence[bool], + algo:QLearningAlgoBase=None) -> None: + self._replay_ratio = replay_ratio + self._layer_reset = layer_reset + self._check = False + if algo is not None: + self._check_layer_resets(algo=algo) + + + def _check_layer_resets(self, algo:QLearningAlgoBase): + assert algo._impl is not None, IMPL_NOT_INITIALIZED_ERROR + assert isinstance(algo._impl, QLearningAlgoImplBase) + valid_layers = [ + hasattr(layer, 'reset_parameters') for lr, layer in zip( + self._layer_reset, algo._impl.q_function) + if lr + ] + self._check = all(valid_layers) + if not self._check: + raise ValueError( + "Some layer do not contain resettable parameters" + ) + + def __call__(self, algo: QLearningAlgoBase, epoch: int, total_step: int): + if not self._check: + self._check_layer_resets(algo=algo) + assert isinstance(algo._impl, QLearningAlgoImplBase) + if epoch % self._replay_ratio == 0: + for 
lr, layer in enumerate( + zip(self._layer_reset, algo._impl.q_function) + ): + if lr: + layer.reset_parameters() \ No newline at end of file From 2846d2040931c000f624ec95bc80ed7a68eb20e4 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Tue, 29 Aug 2023 11:57:27 +0100 Subject: [PATCH 08/10] corrected formatting --- d3rlpy/algos/qlearning/torch/cql_impl.py | 23 ++++++++++++++--------- d3rlpy/algos/qlearning/torch/ddpg_impl.py | 12 ++++++++---- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 31ba9606..15525bed 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -15,14 +15,17 @@ build_squashed_gaussian_distribution, ) from ....torch_utility import TorchMiniBatch +from .ddpg_impl import DDPGCriticLoss from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules from .sac_impl import SACImpl, SACModules -from .ddpg_impl import DDPGCriticLoss __all__ = [ - "CQLImpl", "DiscreteCQLImpl", "CQLModules", "DiscreteCQLLoss", - "CQLLoss" - ] + "CQLImpl", + "DiscreteCQLImpl", + "CQLModules", + "DiscreteCQLLoss", + "CQLLoss", +] @dataclasses.dataclass(frozen=True) @@ -82,9 +85,10 @@ def compute_critic_loss( batch.observations, batch.actions, batch.next_observations ) return CQLLoss( - loss=loss+conservative_loss, td_loss=loss, - conservative_loss=conservative_loss - ) + loss=loss + conservative_loss, + td_loss=loss, + conservative_loss=conservative_loss, + ) def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.critic_optim.zero_grad() @@ -300,6 +304,7 @@ def compute_loss( ) loss = td_loss + self._alpha * conservative_loss return DiscreteCQLLoss( - loss=loss, td_loss=td_loss, - conservative_loss=self._alpha * conservative_loss + loss=loss, + td_loss=td_loss, + conservative_loss=self._alpha * conservative_loss, ) diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index ce442908..26a09d7e 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -14,9 +14,12 @@ from .utility import ContinuousQFunctionMixin __all__ = [ - "DDPGImpl", "DDPGBaseImpl", "DDPGBaseModules", "DDPGModules", - "DDPGCriticLoss" - ] + "DDPGImpl", + "DDPGBaseImpl", + "DDPGBaseModules", + "DDPGModules", + "DDPGCriticLoss", +] @dataclasses.dataclass(frozen=True) @@ -27,6 +30,7 @@ class DDPGBaseModules(Modules): actor_optim: Optimizer critic_optim: Optimizer + @dataclasses.dataclass(frozen=True) class DDPGCriticLoss: loss: torch.Tensor @@ -75,7 +79,7 @@ def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.critic_optim.step() return asdict_as_float(loss) - + def compute_critic_loss( self, batch: TorchMiniBatch, q_tpn: torch.Tensor ) -> DDPGCriticLoss: From 7017e6fc9a5e93d6bdb70d07a5d765dcd563961e Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Tue, 29 Aug 2023 13:33:56 +0100 Subject: [PATCH 09/10] first go at tests for param reset callback --- .../qlearning}/torch/callbacks.py | 14 ++- tests/algos/qlearning/test_callbacks.py | 113 ++++++++++++++++++ 2 files changed, 122 insertions(+), 5 deletions(-) rename d3rlpy/{models => algos/qlearning}/torch/callbacks.py (82%) create mode 100644 tests/algos/qlearning/test_callbacks.py diff --git a/d3rlpy/models/torch/callbacks.py b/d3rlpy/algos/qlearning/torch/callbacks.py similarity index 82% rename from d3rlpy/models/torch/callbacks.py rename to d3rlpy/algos/qlearning/torch/callbacks.py index 
1834e981..ccd128bf 100644 --- a/d3rlpy/models/torch/callbacks.py +++ b/d3rlpy/algos/qlearning/torch/callbacks.py @@ -1,9 +1,12 @@ from abc import ABCMeta, abstractmethod from typing import Sequence -from ...algos import QLearningAlgoBase, QLearningAlgoImplBase -from ...constants import IMPL_NOT_INITIALIZED_ERROR +from ... import QLearningAlgoBase, QLearningAlgoImplBase +from ....constants import IMPL_NOT_INITIALIZED_ERROR +__all__ = [ + "ParameterReset" +] class QLearningCallback(metaclass=ABCMeta): @abstractmethod @@ -24,6 +27,9 @@ def __init__(self, replay_ratio: int, layer_reset:Sequence[bool], def _check_layer_resets(self, algo:QLearningAlgoBase): assert algo._impl is not None, IMPL_NOT_INITIALIZED_ERROR assert isinstance(algo._impl, QLearningAlgoImplBase) + + if len(self._layer_reset) != len(algo._impl.q_function): + raise ValueError valid_layers = [ hasattr(layer, 'reset_parameters') for lr, layer in zip( self._layer_reset, algo._impl.q_function) @@ -40,8 +46,6 @@ def __call__(self, algo: QLearningAlgoBase, epoch: int, total_step: int): self._check_layer_resets(algo=algo) assert isinstance(algo._impl, QLearningAlgoImplBase) if epoch % self._replay_ratio == 0: - for lr, layer in enumerate( - zip(self._layer_reset, algo._impl.q_function) - ): + for lr, layer in zip(self._layer_reset, algo._impl.q_function): if lr: layer.reset_parameters() \ No newline at end of file diff --git a/tests/algos/qlearning/test_callbacks.py b/tests/algos/qlearning/test_callbacks.py new file mode 100644 index 00000000..7bfec3fa --- /dev/null +++ b/tests/algos/qlearning/test_callbacks.py @@ -0,0 +1,113 @@ +import pytest +from typing import Any, Sequence, List, Union +from unittest.mock import MagicMock, Mock +from d3rlpy.dataset import Shape + +from d3rlpy.algos.qlearning.torch.callbacks import ParameterReset +from d3rlpy.algos import QLearningAlgoBase, QLearningAlgoImplBase +from d3rlpy.torch_utility import Modules +import torch + +from ...test_torch_utility import DummyModules + + +class LayerHasResetMock: + + def reset_parameters(self): + return True + +class LayerNoResetMock: + pass + +fc = torch.nn.Linear(100, 100) +optim = torch.optim.Adam(fc.parameters()) +modules = DummyModules(fc=fc, optim=optim) + +class ImplMock(MagicMock): + + def __init__( + self, q_funcs:List[Union[LayerHasResetMock, LayerNoResetMock]] + ) -> None: + super().__init__(spec=QLearningAlgoImplBase) + self.q_function = q_funcs + + +class QLearningAlgoBaseMock(MagicMock): + + def __init__(self, spec, layer_setup:Sequence[bool]) -> None: + super().__init__(spec=spec) + q_funcs = [] + for i in layer_setup: + if i: + q_funcs.append(LayerHasResetMock()) + else: + q_funcs.append(LayerNoResetMock()) + self._impl = ImplMock(q_funcs=q_funcs) + + + +def test_check_layer_resets(): + algo = QLearningAlgoBaseMock(spec=QLearningAlgoBase, + layer_setup=[True, True, False]) + replay_ratio = 2 + layer_reset_valid = [True, True, False] + pr = ParameterReset( + replay_ratio=replay_ratio, + layer_reset=layer_reset_valid, + algo=algo + ) + assert pr._check is True + + layer_reset_invalid = [True, True, True] + try: + pr = ParameterReset( + replay_ratio=replay_ratio, + layer_reset=layer_reset_invalid, + algo=algo + ) + raise Exception + except ValueError as e: + assert True + + layer_reset_long = [True, True, True, False] + try: + pr = ParameterReset( + replay_ratio=replay_ratio, + layer_reset=layer_reset_long, + algo=algo + ) + raise Exception + except ValueError as e: + assert True + + layer_reset_shrt = [True, True] + try: + pr = ParameterReset( + 
replay_ratio=replay_ratio, + layer_reset=layer_reset_shrt, + algo=algo + ) + raise Exception + except ValueError as e: + assert True + + +def test_call(): + algo = QLearningAlgoBaseMock(spec=QLearningAlgoBase, + layer_setup=[True, True, False]) + replay_ratio = 2 + layer_reset_valid = [True, True, False] + pr = ParameterReset( + replay_ratio=replay_ratio, + layer_reset=layer_reset_valid, + algo=algo + ) + pr(algo=algo, epoch=1, total_step=100) + pr(algo=algo, epoch=2, total_step=100) + + pr = ParameterReset( + replay_ratio=replay_ratio, + layer_reset=layer_reset_valid, + ) + pr(algo=algo, epoch=1, total_step=100) + pr(algo=algo, epoch=2, total_step=100) \ No newline at end of file From d793c6f9d2894e0650b26919458bf559dffcfe9b Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Tue, 29 Aug 2023 18:23:18 +0100 Subject: [PATCH 10/10] fixed issues in call method --- d3rlpy/algos/qlearning/torch/callbacks.py | 52 ++++++++++++++++------- 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/d3rlpy/algos/qlearning/torch/callbacks.py b/d3rlpy/algos/qlearning/torch/callbacks.py index ccd128bf..09fe0dff 100644 --- a/d3rlpy/algos/qlearning/torch/callbacks.py +++ b/d3rlpy/algos/qlearning/torch/callbacks.py @@ -1,5 +1,6 @@ from abc import ABCMeta, abstractmethod -from typing import Sequence +from typing import Sequence, List +import torch.nn as nn from ... import QLearningAlgoBase, QLearningAlgoImplBase from ....constants import IMPL_NOT_INITIALIZED_ERROR @@ -15,27 +16,45 @@ def __call__(self, algo: QLearningAlgoBase, epoch: int, total_step: int): class ParameterReset(QLearningCallback): - def __init__(self, replay_ratio: int, layer_reset:Sequence[bool], - algo:QLearningAlgoBase=None) -> None: + def __init__(self, replay_ratio: int, encoder_reset:Sequence[bool], + output_reset:bool, algo:QLearningAlgoBase=None) -> None: self._replay_ratio = replay_ratio - self._layer_reset = layer_reset + self._encoder_reset = encoder_reset + self._output_reset = output_reset self._check = False if algo is not None: self._check_layer_resets(algo=algo) + def _get_layers(self, q_func:nn.ModuleList)->List[nn.Module]: + all_modules = {nm:module for (nm, module) in q_func.named_modules()} + q_func_layers = [ + *all_modules["_encoder._layers"], + all_modules["_fc"] + ] + return q_func_layers + def _check_layer_resets(self, algo:QLearningAlgoBase): assert algo._impl is not None, IMPL_NOT_INITIALIZED_ERROR assert isinstance(algo._impl, QLearningAlgoImplBase) - if len(self._layer_reset) != len(algo._impl.q_function): - raise ValueError - valid_layers = [ - hasattr(layer, 'reset_parameters') for lr, layer in zip( - self._layer_reset, algo._impl.q_function) - if lr - ] - self._check = all(valid_layers) + all_valid_layers = [] + for q_func in algo._impl.q_function: + q_func_layers = self._get_layers(q_func) + if len(self._encoder_reset) + 1 != len(q_func_layers): + raise ValueError( + f""" + q_function layers: {q_func_layers}; + specified encoder layers: {self._encoder_reset} + """ + ) + valid_layers = [ + hasattr(layer, 'reset_parameters') for lr, layer in zip( + self._encoder_reset, q_func_layers) + if lr + ] + all_valid_layers.append(all(valid_layers)) + self._check = all(all_valid_layers) if not self._check: raise ValueError( "Some layer do not contain resettable parameters" @@ -46,6 +65,9 @@ def __call__(self, algo: QLearningAlgoBase, epoch: int, total_step: int): self._check_layer_resets(algo=algo) assert isinstance(algo._impl, QLearningAlgoImplBase) if epoch % self._replay_ratio == 0: - for lr, layer in 
zip(self._layer_reset, algo._impl.q_function): - if lr: - layer.reset_parameters() \ No newline at end of file + reset_lst = [*self._encoder_reset, self._output_reset] + for q_func in algo._impl.q_function: + q_func_layers = self._get_layers(q_func) + for lr, layer in zip(reset_lst, q_func_layers): + if lr: + layer.reset_parameters() \ No newline at end of file
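
With the logging changes in patches 01-06, the CQL critic update reports the conservative regulariser separately from the TD term, so both can be monitored during training. Below is a minimal sketch of reading those metrics, assuming fit() returns a list of (epoch, metrics) pairs and that the CQLLoss fields are flattened into metric keys named after the dataclass fields ("td_loss", "conservative_loss"); both points are assumptions about the surrounding d3rlpy version rather than something fixed by these patches.

    import d3rlpy

    # Continuous-control dataset so the continuous CQL path is exercised.
    dataset, env = d3rlpy.datasets.get_pendulum()

    cql = d3rlpy.algos.CQLConfig().create()

    # fit() is assumed to return [(epoch, metrics_dict), ...]; the key names
    # below follow the CQLLoss dataclass fields and may be named differently
    # once flattened by asdict_as_float.
    history = cql.fit(dataset, n_steps=10_000, n_steps_per_epoch=1_000)
    for epoch, metrics in history:
        print(epoch, metrics.get("td_loss"), metrics.get("conservative_loss"))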
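
The ParameterReset callback introduced in patches 07-10 re-initialises selected Q-function layers every replay_ratio epochs while leaving the remaining layers untouched, in the spirit of the periodic-reset schemes used against primacy bias in off-policy RL. Below is a minimal usage sketch, reusing cql and dataset from the sketch above and assuming the default vector encoder with two hidden layers plus an output head; the fit() keyword that receives an (algo, epoch, total_step) callback is also an assumption and may differ between d3rlpy versions.

    from d3rlpy.algos.qlearning.torch.callbacks import ParameterReset

    # Reset both encoder layers but keep the output head, every 4 epochs.
    # encoder_reset needs one flag per encoder layer; the callback walks
    # algo._impl.q_function and expects `_encoder._layers` and `_fc`
    # submodules, so it only applies to Q-functions built on that layout.
    reset_cb = ParameterReset(
        replay_ratio=4,
        encoder_reset=[True, True],
        output_reset=False,
    )

    cql.fit(
        dataset,
        n_steps=10_000,
        n_steps_per_epoch=1_000,
        callback=reset_cb,  # assumed hook; invoked with (algo, epoch, total_step)
    )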