diff --git a/d3rlpy/algos/qlearning/ddpg.py b/d3rlpy/algos/qlearning/ddpg.py index abb40896..32de3dba 100644 --- a/d3rlpy/algos/qlearning/ddpg.py +++ b/d3rlpy/algos/qlearning/ddpg.py @@ -11,7 +11,8 @@ from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field from ...types import Shape from .base import QLearningAlgoBase -from .torch.ddpg_impl import DDPGImpl, DDPGModules +from .torch.ddpg_impl import DDPGActionSampler, DDPGValuePredictor, DDPGCriticLossFn, DDPGActorLossFn, DDPGUpdater, DDPGModules +from .functional import FunctionalQLearningAlgoImplBase __all__ = ["DDPGConfig", "DDPG"] @@ -93,7 +94,7 @@ def get_type() -> str: return "ddpg" -class DDPG(QLearningAlgoBase[DDPGImpl, DDPGConfig]): +class DDPG(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, DDPGConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: @@ -150,15 +151,41 @@ def inner_create_impl( critic_optim=critic_optim, ) - self._impl = DDPGImpl( + updater = DDPGUpdater( + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + policy=policy, + targ_policy=targ_policy, + critic_optim=critic_optim, + actor_optim=actor_optim, + critic_loss_fn=DDPGCriticLossFn( + q_func_forwarder=q_func_forwarder, + targ_q_func_forwarder=targ_q_func_forwarder, + targ_policy=targ_policy, + gamma=self._config.gamma, + ), + actor_loss_fn=DDPGActorLossFn( + q_func_forwarder=q_func_forwarder, + policy=policy, + ), + tau=self._config.tau, + compiled=self.compiled, + ) + action_sampler = DDPGActionSampler(policy) + value_predictor = DDPGValuePredictor(q_func_forwarder) + + self._impl = FunctionalQLearningAlgoImplBase( observation_shape=observation_shape, action_size=action_size, modules=modules, - q_func_forwarder=q_func_forwarder, - targ_q_func_forwarder=targ_q_func_forwarder, - gamma=self._config.gamma, - tau=self._config.tau, - compiled=self.compiled, + updater=updater, + exploit_action_sampler=action_sampler, + explore_action_sampler=action_sampler, + value_predictor=value_predictor, + q_function=q_funcs, + q_function_optim=critic_optim.optim, + policy=policy, + policy_optim=actor_optim.optim, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/sac.py b/d3rlpy/algos/qlearning/sac.py index ef66c40a..4ca3a5dd 100644 --- a/d3rlpy/algos/qlearning/sac.py +++ b/d3rlpy/algos/qlearning/sac.py @@ -15,12 +15,17 @@ from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field from ...types import Shape from .base import QLearningAlgoBase +from .torch.ddpg_impl import DDPGActionSampler, DDPGValuePredictor from .torch.sac_impl import ( DiscreteSACImpl, DiscreteSACModules, - SACImpl, SACModules, + SACActionSampler, + SACCriticLossFn, + SACActorLossFn, + SACUpdater, ) +from .functional import FunctionalQLearningAlgoImplBase __all__ = ["SACConfig", "SAC", "DiscreteSACConfig", "DiscreteSAC"] @@ -122,7 +127,7 @@ def get_type() -> str: return "sac" -class SAC(QLearningAlgoBase[SACImpl, SACConfig]): +class SAC(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, SACConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: @@ -187,15 +192,44 @@ def inner_create_impl( temp_optim=temp_optim, ) - self._impl = SACImpl( + updater = SACUpdater( + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + critic_optim=critic_optim, + actor_optim=actor_optim, + critic_loss_fn=SACCriticLossFn( + q_func_forwarder=q_func_forwarder, + targ_q_func_forwarder=targ_q_func_forwarder, + policy=policy, + log_temp=log_temp, + gamma=self._config.gamma, + ), + actor_loss_fn=SACActorLossFn( + 
q_func_forwarder=q_func_forwarder, + policy=policy, + log_temp=log_temp, + temp_optim=temp_optim, + action_size=action_size, + ), + tau=self._config.tau, + compiled=self.compiled, + ) + exploit_action_sampler = DDPGActionSampler(policy) + explore_action_sampler = SACActionSampler(policy) + value_predictor = DDPGValuePredictor(q_func_forwarder) + + self._impl = FunctionalQLearningAlgoImplBase( observation_shape=observation_shape, action_size=action_size, modules=modules, - q_func_forwarder=q_func_forwarder, - targ_q_func_forwarder=targ_q_func_forwarder, - gamma=self._config.gamma, - tau=self._config.tau, - compiled=self.compiled, + updater=updater, + exploit_action_sampler=exploit_action_sampler, + explore_action_sampler=explore_action_sampler, + value_predictor=value_predictor, + q_function=q_funcs, + q_function_optim=critic_optim.optim, + policy=policy, + policy_optim=actor_optim.optim, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3.py b/d3rlpy/algos/qlearning/td3.py index e61bf7b4..b913bfb5 100644 --- a/d3rlpy/algos/qlearning/td3.py +++ b/d3rlpy/algos/qlearning/td3.py @@ -11,8 +11,9 @@ from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field from ...types import Shape from .base import QLearningAlgoBase -from .torch.ddpg_impl import DDPGModules -from .torch.td3_impl import TD3Impl +from .functional import FunctionalQLearningAlgoImplBase +from .torch.ddpg_impl import DDPGModules, DDPGActionSampler, DDPGValuePredictor, DDPGActorLossFn +from .torch.td3_impl import TD3CriticLossFn, TD3Updater __all__ = ["TD3Config", "TD3"] @@ -102,7 +103,7 @@ def get_type() -> str: return "td3" -class TD3(QLearningAlgoBase[TD3Impl, TD3Config]): +class TD3(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, TD3Config]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: @@ -159,18 +160,44 @@ def inner_create_impl( critic_optim=critic_optim, ) - self._impl = TD3Impl( - observation_shape=observation_shape, - action_size=action_size, - modules=modules, - q_func_forwarder=q_func_forwarder, - targ_q_func_forwarder=targ_q_func_forwarder, - gamma=self._config.gamma, + updater = TD3Updater( + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + policy=policy, + targ_policy=targ_policy, + critic_optim=critic_optim, + actor_optim=actor_optim, + critic_loss_fn=TD3CriticLossFn( + q_func_forwarder=q_func_forwarder, + targ_q_func_forwarder=targ_q_func_forwarder, + targ_policy=targ_policy, + gamma=self._config.gamma, + target_smoothing_sigma=self._config.target_smoothing_sigma, + target_smoothing_clip=self._config.target_smoothing_clip, + ), + actor_loss_fn=DDPGActorLossFn( + q_func_forwarder=q_func_forwarder, + policy=policy, + ), tau=self._config.tau, - target_smoothing_sigma=self._config.target_smoothing_sigma, - target_smoothing_clip=self._config.target_smoothing_clip, update_actor_interval=self._config.update_actor_interval, compiled=self.compiled, + ) + action_sampler = DDPGActionSampler(policy) + value_predictor = DDPGValuePredictor(q_func_forwarder) + + self._impl = FunctionalQLearningAlgoImplBase( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + updater=updater, + exploit_action_sampler=action_sampler, + explore_action_sampler=action_sampler, + value_predictor=value_predictor, + q_function=q_funcs, + q_function_optim=critic_optim.optim, + policy=policy, + policy_optim=actor_optim.optim, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3_plus_bc.py b/d3rlpy/algos/qlearning/td3_plus_bc.py index 
62758a40..b9a59880 100644 --- a/d3rlpy/algos/qlearning/td3_plus_bc.py +++ b/d3rlpy/algos/qlearning/td3_plus_bc.py @@ -11,8 +11,10 @@ from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field from ...types import Shape from .base import QLearningAlgoBase -from .torch.ddpg_impl import DDPGModules -from .torch.td3_plus_bc_impl import TD3PlusBCImpl +from .functional import FunctionalQLearningAlgoImplBase +from .torch.ddpg_impl import DDPGModules, DDPGValuePredictor, DDPGActionSampler +from .torch.td3_impl import TD3CriticLossFn, TD3Updater +from .torch.td3_plus_bc_impl import TD3PlusBCActorLossFn __all__ = ["TD3PlusBCConfig", "TD3PlusBC"] @@ -94,7 +96,7 @@ def get_type() -> str: return "td3_plus_bc" -class TD3PlusBC(QLearningAlgoBase[TD3PlusBCImpl, TD3PlusBCConfig]): +class TD3PlusBC(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, TD3PlusBCConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: @@ -151,19 +153,45 @@ def inner_create_impl( critic_optim=critic_optim, ) - self._impl = TD3PlusBCImpl( - observation_shape=observation_shape, - action_size=action_size, - modules=modules, - q_func_forwarder=q_func_forwarder, - targ_q_func_forwarder=targ_q_func_forwarder, - gamma=self._config.gamma, + updater = TD3Updater( + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + policy=policy, + targ_policy=targ_policy, + critic_optim=critic_optim, + actor_optim=actor_optim, + critic_loss_fn=TD3CriticLossFn( + q_func_forwarder=q_func_forwarder, + targ_q_func_forwarder=targ_q_func_forwarder, + targ_policy=targ_policy, + gamma=self._config.gamma, + target_smoothing_sigma=self._config.target_smoothing_sigma, + target_smoothing_clip=self._config.target_smoothing_clip, + ), + actor_loss_fn=TD3PlusBCActorLossFn( + q_func_forwarder=q_func_forwarder, + policy=policy, + alpha=self._config.alpha, + ), tau=self._config.tau, - target_smoothing_sigma=self._config.target_smoothing_sigma, - target_smoothing_clip=self._config.target_smoothing_clip, - alpha=self._config.alpha, update_actor_interval=self._config.update_actor_interval, compiled=self.compiled, + ) + action_sampler = DDPGActionSampler(policy) + value_predictor = DDPGValuePredictor(q_func_forwarder) + + self._impl = FunctionalQLearningAlgoImplBase( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + updater=updater, + exploit_action_sampler=action_sampler, + explore_action_sampler=action_sampler, + value_predictor=value_predictor, + q_function=q_funcs, + q_function_optim=critic_optim.optim, + policy=policy, + policy_optim=actor_optim.optim, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index f92fc15f..fd4ba198 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -1,14 +1,11 @@ import dataclasses from abc import ABCMeta, abstractmethod -from typing import Callable import torch from torch import nn -from torch.optim import Optimizer from ....dataclass_utils import asdict_as_float from ....models.torch import ( - ActionOutput, ContinuousEnsembleQFunctionForwarder, Policy, ) @@ -17,20 +14,24 @@ CudaGraphWrapper, Modules, TorchMiniBatch, - hard_sync, soft_sync, ) -from ....types import Shape, TorchObservation -from ..base import QLearningAlgoImplBase -from .utility import ContinuousQFunctionMixin +from ....types import TorchObservation +from ..functional import ActionSampler, Updater, ValuePredictor __all__ = [ - "DDPGImpl", - "DDPGBaseImpl",
"DDPGBaseModules", "DDPGModules", "DDPGBaseActorLoss", "DDPGBaseCriticLoss", + "DDPGBaseCriticLossFn", + "DDPGBaseActorLossFn", + "DDPGCriticLossFn", + "DDPGActorLossFn", + "DDPGActionSampler", + "DDPGValuePredictor", + "DDPGBaseUpdater", + "DDPGUpdater", ] @@ -53,195 +54,180 @@ class DDPGBaseCriticLoss: critic_loss: torch.Tensor -class DDPGBaseImpl( - ContinuousQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta -): - _modules: DDPGBaseModules - _compute_critic_grad: Callable[[TorchMiniBatch], DDPGBaseCriticLoss] - _compute_actor_grad: Callable[[TorchMiniBatch], DDPGBaseActorLoss] - _gamma: float - _tau: float - _q_func_forwarder: ContinuousEnsembleQFunctionForwarder - _targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder - +class DDPGBaseCriticLossFn(metaclass=ABCMeta): def __init__( self, - observation_shape: Shape, - action_size: int, - modules: DDPGBaseModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, gamma: float, - tau: float, - compiled: bool, - device: str, ): - super().__init__( - observation_shape=observation_shape, - action_size=action_size, - modules=modules, - device=device, - ) - self._gamma = gamma - self._tau = tau self._q_func_forwarder = q_func_forwarder self._targ_q_func_forwarder = targ_q_func_forwarder - self._compute_critic_grad = ( - CudaGraphWrapper(self.compute_critic_grad) - if compiled - else self.compute_critic_grad - ) - self._compute_actor_grad = ( - CudaGraphWrapper(self.compute_actor_grad) - if compiled - else self.compute_actor_grad - ) - hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs) + self._gamma = gamma - def compute_critic_grad(self, batch: TorchMiniBatch) -> DDPGBaseCriticLoss: - self._modules.critic_optim.zero_grad() + def __call__(self, batch: TorchMiniBatch) -> DDPGBaseCriticLoss: q_tpn = self.compute_target(batch) - loss = self.compute_critic_loss(batch, q_tpn) - loss.critic_loss.backward() - return loss - - def update_critic(self, batch: TorchMiniBatch) -> dict[str, float]: - loss = self._compute_critic_grad(batch) - self._modules.critic_optim.step() - return asdict_as_float(loss) - - def compute_critic_loss( - self, batch: TorchMiniBatch, q_tpn: torch.Tensor - ) -> DDPGBaseCriticLoss: loss = self._q_func_forwarder.compute_error( observations=batch.observations, actions=batch.actions, rewards=batch.rewards, target=q_tpn, terminals=batch.terminals, gamma=self._gamma**batch.intervals, ) - return DDPGBaseCriticLoss(loss) - - def compute_actor_grad(self, batch: TorchMiniBatch) -> DDPGBaseActorLoss: - action = self._modules.policy(batch.observations) - self._modules.actor_optim.zero_grad() - loss = self.compute_actor_loss(batch, action) - loss.actor_loss.backward() - return loss - - def update_actor(self, batch: TorchMiniBatch) -> dict[str, float]: - # Q function should be inference mode for stability - self._modules.q_funcs.eval() - loss = self._compute_actor_grad(batch) - self._modules.actor_optim.step() - return asdict_as_float(loss) - - def inner_update( - self, batch: TorchMiniBatch, grad_step: int - ) -> dict[str, float]: - metrics = {} - metrics.update(self.update_critic(batch)) - metrics.update(self.update_actor(batch)) - self.update_critic_target() - return metrics - - @abstractmethod - def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput - ) -> DDPGBaseActorLoss: - pass + return DDPGBaseCriticLoss(critic_loss=loss) @abstractmethod def compute_target(self, batch: TorchMiniBatch) -> 
torch.Tensor: - pass + raise NotImplementedError - def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor: - return self._modules.policy(x).squashed_mu +class DDPGBaseActorLossFn(metaclass=ABCMeta): @abstractmethod - def inner_sample_action(self, x: TorchObservation) -> torch.Tensor: - pass - - def update_critic_target(self) -> None: - soft_sync(self._modules.targ_q_funcs, self._modules.q_funcs, self._tau) - - @property - def policy(self) -> Policy: - return self._modules.policy - - @property - def policy_optim(self) -> Optimizer: - return self._modules.actor_optim.optim - - @property - def q_function(self) -> nn.ModuleList: - return self._modules.q_funcs - - @property - def q_function_optim(self) -> Optimizer: - return self._modules.critic_optim.optim - - -@dataclasses.dataclass(frozen=True) -class DDPGModules(DDPGBaseModules): - targ_policy: Policy - + def __call__(self, batch: TorchMiniBatch) -> DDPGBaseActorLoss: + raise NotImplementedError -class DDPGImpl(DDPGBaseImpl): - _modules: DDPGModules +class DDPGCriticLossFn(DDPGBaseCriticLossFn): def __init__( self, - observation_shape: Shape, - action_size: int, - modules: DDPGModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + targ_policy: Policy, gamma: float, - tau: float, - compiled: bool, - device: str, ): super().__init__( - observation_shape=observation_shape, - action_size=action_size, - modules=modules, q_func_forwarder=q_func_forwarder, targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, - tau=tau, - compiled=compiled, - device=device, ) - hard_sync(self._modules.targ_policy, self._modules.policy) - - def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput - ) -> DDPGBaseActorLoss: - q_t = self._q_func_forwarder.compute_expected_q( - batch.observations, action.squashed_mu, "none" - )[0] - return DDPGBaseActorLoss(-q_t.mean()) + self._targ_policy = targ_policy def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): - action = self._modules.targ_policy(batch.next_observations) + action = self._targ_policy(batch.next_observations) return self._targ_q_func_forwarder.compute_target( batch.next_observations, action.squashed_mu.clamp(-1.0, 1.0), reduction="min", ) - def inner_sample_action(self, x: TorchObservation) -> torch.Tensor: - return self.inner_predict_best_action(x) - def update_actor_target(self) -> None: - soft_sync(self._modules.targ_policy, self._modules.policy, self._tau) +class DDPGActorLossFn(DDPGBaseActorLossFn): + def __init__(self, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, policy: Policy): + self._q_func_forwarder = q_func_forwarder + self._policy = policy + + def __call__(self, batch: TorchMiniBatch) -> DDPGBaseActorLoss: + action = self._policy(batch.observations) + q_t = self._q_func_forwarder.compute_expected_q( + batch.observations, action.squashed_mu, "none" + )[0] + return DDPGBaseActorLoss(-q_t.mean()) + + +class DDPGBaseUpdater(Updater): + def __init__( + self, + critic_optim: OptimizerWrapper, + actor_optim: OptimizerWrapper, + critic_loss_fn: DDPGBaseCriticLossFn, + actor_loss_fn: DDPGBaseActorLossFn, + compiled: bool, + ): + self._critic_optim = critic_optim + self._actor_optim = actor_optim + self._critic_loss_fn = critic_loss_fn + self._actor_loss_fn = actor_loss_fn + self._compute_critic_grad = CudaGraphWrapper(self.compute_critic_grad) if compiled else self.compute_critic_grad + self._compute_actor_grad = 
CudaGraphWrapper(self.compute_actor_grad) if compiled else self.compute_actor_grad + + def compute_critic_grad(self, batch: TorchMiniBatch) -> DDPGBaseCriticLoss: + self._critic_optim.zero_grad() + loss = self._critic_loss_fn(batch) + loss.critic_loss.backward() + return loss + + def compute_actor_grad(self, batch: TorchMiniBatch) -> DDPGBaseActorLoss: + self._actor_optim.zero_grad() + loss = self._actor_loss_fn(batch) + loss.actor_loss.backward() + return loss + + def __call__(self, batch: TorchMiniBatch, grad_step: int) -> dict[str, float]: + metrics = {} + + # update critic + critic_loss = self._compute_critic_grad(batch) + self._critic_optim.step() + metrics.update(asdict_as_float(critic_loss)) + + # update actor + actor_loss = self._compute_actor_grad(batch) + self._actor_optim.step() + metrics.update(asdict_as_float(actor_loss)) + + # update target networks + self.update_target() - def inner_update( - self, batch: TorchMiniBatch, grad_step: int - ) -> dict[str, float]: - metrics = super().inner_update(batch, grad_step) - self.update_actor_target() return metrics + + @abstractmethod + def update_target(self) -> None: + raise NotImplementedError + + +@dataclasses.dataclass(frozen=True) +class DDPGModules(DDPGBaseModules): + targ_policy: Policy + + +class DDPGActionSampler(ActionSampler): + def __init__(self, policy: Policy): + self._policy = policy + + def __call__(self, x: TorchObservation) -> torch.Tensor: + action = self._policy(x) + return action.squashed_mu + + +class DDPGValuePredictor(ValuePredictor): + def __init__(self, q_func_forwarder: ContinuousEnsembleQFunctionForwarder): + self._q_func_forwarder = q_func_forwarder + + def __call__(self, x: TorchObservation, action: torch.Tensor) -> torch.Tensor: + return self._q_func_forwarder.compute_expected_q( + x, action, reduction="mean" + ).reshape(-1) + + +class DDPGUpdater(DDPGBaseUpdater): + def __init__( + self, + q_funcs: nn.ModuleList, + targ_q_funcs: nn.ModuleList, + policy: Policy, + targ_policy: Policy, + critic_optim: OptimizerWrapper, + actor_optim: OptimizerWrapper, + critic_loss_fn: DDPGBaseCriticLossFn, + actor_loss_fn: DDPGBaseActorLossFn, + tau: float, + compiled: bool, + ): + super().__init__( + critic_optim=critic_optim, + actor_optim=actor_optim, + critic_loss_fn=critic_loss_fn, + actor_loss_fn=actor_loss_fn, + compiled=compiled, + ) + self._q_funcs = q_funcs + self._targ_q_funcs = targ_q_funcs + self._policy = policy + self._targ_policy = targ_policy + self._tau = tau + + def update_target(self) -> None: + soft_sync(self._targ_q_funcs, self._q_funcs, self._tau) + soft_sync(self._targ_policy, self._policy, self._tau) diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index e2c9a7ed..47ffd5f7 100644 --- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -25,14 +25,19 @@ Modules, TorchMiniBatch, hard_sync, + soft_sync, ) from ....types import Shape, TorchObservation from ..base import QLearningAlgoImplBase -from .ddpg_impl import DDPGBaseActorLoss, DDPGBaseImpl, DDPGBaseModules +from .ddpg_impl import DDPGBaseActorLoss, DDPGBaseModules, DDPGBaseCriticLossFn, DDPGBaseActorLossFn, DDPGBaseUpdater from .utility import DiscreteQFunctionMixin +from ..functional import ActionSampler __all__ = [ - "SACImpl", + "SACActionSampler", + "SACCriticLossFn", + "SACActorLossFn", + "SACUpdater", "DiscreteSACImpl", "SACModules", "DiscreteSACModules", @@ -53,83 +58,120 @@ class SACActorLoss(DDPGBaseActorLoss): temp_loss: torch.Tensor -class 
SACImpl(DDPGBaseImpl): - _modules: SACModules +class SACActionSampler(ActionSampler): + def __init__(self, policy: Policy): + self._policy = policy + def __call__(self, x: TorchObservation) -> torch.Tensor: + dist = build_squashed_gaussian_distribution(self._policy(x)) + return dist.sample() + + +class SACCriticLossFn(DDPGBaseCriticLossFn): def __init__( self, - observation_shape: Shape, - action_size: int, - modules: SACModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + policy: Policy, + log_temp: Parameter, gamma: float, - tau: float, - compiled: bool, - device: str, ): super().__init__( - observation_shape=observation_shape, - action_size=action_size, - modules=modules, q_func_forwarder=q_func_forwarder, targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, - tau=tau, - compiled=compiled, - device=device, ) + self._policy = policy + self._log_temp = log_temp - def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput - ) -> SACActorLoss: + def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: + with torch.no_grad(): + dist = build_squashed_gaussian_distribution( + self._policy(batch.next_observations) + ) + action, log_prob = dist.sample_with_log_prob() + entropy = get_parameter(self._log_temp).exp() * log_prob + target = self._targ_q_func_forwarder.compute_target( + batch.next_observations, + action, + reduction="min", + ) + return target - entropy + + +class SACActorLossFn(DDPGBaseActorLossFn): + def __init__( + self, + q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + policy: Policy, + log_temp: Parameter, + temp_optim: Optional[OptimizerWrapper], + action_size: int, + ): + self._q_func_forwarder = q_func_forwarder + self._policy = policy + self._log_temp = log_temp + self._temp_optim = temp_optim + self._action_size = action_size + + def update_temp(self, log_prob: torch.Tensor) -> torch.Tensor: + assert self._temp_optim + self._temp_optim.zero_grad() + with torch.no_grad(): + targ_temp = log_prob - self._action_size + loss = -(get_parameter(self._log_temp).exp() * targ_temp).mean() + loss.backward() + self._temp_optim.step() + return loss + + def __call__(self, batch: TorchMiniBatch) -> SACActorLoss: + action = self._policy(batch.observations) dist = build_squashed_gaussian_distribution(action) sampled_action, log_prob = dist.sample_with_log_prob() - if self._modules.temp_optim: + if self._temp_optim: temp_loss = self.update_temp(log_prob) else: temp_loss = torch.tensor( 0.0, dtype=torch.float32, device=sampled_action.device ) - entropy = get_parameter(self._modules.log_temp).exp() * log_prob + entropy = get_parameter(self._log_temp).exp() * log_prob q_t = self._q_func_forwarder.compute_expected_q( batch.observations, sampled_action, "min" ) return SACActorLoss( actor_loss=(entropy - q_t).mean(), temp_loss=temp_loss, - temp=get_parameter(self._modules.log_temp).exp()[0][0], + temp=get_parameter(self._log_temp).exp()[0][0], ) - def update_temp(self, log_prob: torch.Tensor) -> torch.Tensor: - assert self._modules.temp_optim - self._modules.temp_optim.zero_grad() - with torch.no_grad(): - targ_temp = log_prob - self._action_size - loss = -(get_parameter(self._modules.log_temp).exp() * targ_temp).mean() - loss.backward() - self._modules.temp_optim.step() - return loss - def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: - with torch.no_grad(): - dist = build_squashed_gaussian_distribution( - self._modules.policy(batch.next_observations) - ) - action, 
log_prob = dist.sample_with_log_prob() - entropy = get_parameter(self._modules.log_temp).exp() * log_prob - target = self._targ_q_func_forwarder.compute_target( - batch.next_observations, - action, - reduction="min", - ) - return target - entropy +class SACUpdater(DDPGBaseUpdater): + def __init__( + self, + q_funcs: nn.ModuleList, + targ_q_funcs: nn.ModuleList, + critic_optim: OptimizerWrapper, + actor_optim: OptimizerWrapper, + critic_loss_fn: DDPGBaseCriticLossFn, + actor_loss_fn: DDPGBaseActorLossFn, + tau: float, + compiled: bool, + ): + super().__init__( + critic_optim=critic_optim, + actor_optim=actor_optim, + critic_loss_fn=critic_loss_fn, + actor_loss_fn=actor_loss_fn, + compiled=compiled, + ) + self._q_funcs = q_funcs + self._targ_q_funcs = targ_q_funcs + self._tau = tau - def inner_sample_action(self, x: TorchObservation) -> torch.Tensor: - dist = build_squashed_gaussian_distribution(self._modules.policy(x)) - return dist.sample() + def update_target(self) -> None: + soft_sync(self._targ_q_funcs, self._q_funcs, self._tau) @dataclasses.dataclass(frozen=True) diff --git a/d3rlpy/algos/qlearning/torch/td3_impl.py b/d3rlpy/algos/qlearning/torch/td3_impl.py index 73a43a0c..c6bbf992 100644 --- a/d3rlpy/algos/qlearning/torch/td3_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_impl.py @@ -1,52 +1,37 @@ - import torch +from torch import nn -from ....models.torch import ContinuousEnsembleQFunctionForwarder +from ....models.torch import ContinuousEnsembleQFunctionForwarder, Policy +from ....optimizers import OptimizerWrapper from ....torch_utility import TorchMiniBatch -from ....types import Shape -from .ddpg_impl import DDPGImpl, DDPGModules - -__all__ = ["TD3Impl"] +from ....dataclass_utils import asdict_as_float +from .ddpg_impl import DDPGCriticLossFn, DDPGUpdater, DDPGBaseCriticLossFn, DDPGBaseActorLossFn +__all__ = ["TD3CriticLossFn", "TD3Updater"] -class TD3Impl(DDPGImpl): - _target_smoothing_sigma: float - _target_smoothing_clip: float - _update_actor_interval: int +class TD3CriticLossFn(DDPGCriticLossFn): def __init__( self, - observation_shape: Shape, - action_size: int, - modules: DDPGModules, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + targ_policy: Policy, gamma: float, - tau: float, target_smoothing_sigma: float, target_smoothing_clip: float, - update_actor_interval: int, - compiled: bool, - device: str, ): super().__init__( - observation_shape=observation_shape, - action_size=action_size, - modules=modules, q_func_forwarder=q_func_forwarder, targ_q_func_forwarder=targ_q_func_forwarder, + targ_policy=targ_policy, gamma=gamma, - tau=tau, - compiled=compiled, - device=device, ) self._target_smoothing_sigma = target_smoothing_sigma self._target_smoothing_clip = target_smoothing_clip - self._update_actor_interval = update_actor_interval def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): - action = self._modules.targ_policy(batch.next_observations) + action = self._targ_policy(batch.next_observations) # smoothing target noise = torch.randn(action.mu.shape, device=batch.device) scaled_noise = self._target_smoothing_sigma * noise @@ -61,17 +46,52 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: reduction="min", ) - def inner_update( - self, batch: TorchMiniBatch, grad_step: int - ) -> dict[str, float]: + +class TD3Updater(DDPGUpdater): + def __init__( + self, + q_funcs: nn.ModuleList, + targ_q_funcs: nn.ModuleList, + policy: Policy, + targ_policy: Policy, + 
critic_optim: OptimizerWrapper, + actor_optim: OptimizerWrapper, + critic_loss_fn: DDPGBaseCriticLossFn, + actor_loss_fn: DDPGBaseActorLossFn, + tau: float, + update_actor_interval: int, + compiled: bool, + ): + super().__init__( + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + policy=policy, + targ_policy=targ_policy, + critic_optim=critic_optim, + actor_optim=actor_optim, + critic_loss_fn=critic_loss_fn, + actor_loss_fn=actor_loss_fn, + tau=tau, + compiled=compiled, + ) + self._update_actor_interval = update_actor_interval + + def __call__(self, batch: TorchMiniBatch, grad_step: int) -> dict[str, float]: metrics = {} - metrics.update(self.update_critic(batch)) + # update critic + critic_loss = self._compute_critic_grad(batch) + self._critic_optim.step() + metrics.update(asdict_as_float(critic_loss)) # delayed policy update if grad_step % self._update_actor_interval == 0: - metrics.update(self.update_actor(batch)) - self.update_critic_target() - self.update_actor_target() + # update actor + actor_loss = self._compute_actor_grad(batch) + self._actor_optim.step() + metrics.update(asdict_as_float(actor_loss)) + + # update target networks + self.update_target() return metrics diff --git a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py index 7614e8eb..4a082318 100644 --- a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py @@ -3,13 +3,11 @@ import torch -from ....models.torch import ActionOutput, ContinuousEnsembleQFunctionForwarder +from ....models.torch import ContinuousEnsembleQFunctionForwarder, Policy from ....torch_utility import TorchMiniBatch -from ....types import Shape -from .ddpg_impl import DDPGBaseActorLoss, DDPGModules -from .td3_impl import TD3Impl +from .ddpg_impl import DDPGBaseActorLoss, DDPGBaseActorLossFn -__all__ = ["TD3PlusBCImpl"] +__all__ = ["TD3PlusBCActorLoss", "TD3PlusBCActorLossFn"] @dataclasses.dataclass(frozen=True) @@ -17,44 +15,14 @@ class TD3PlusBCActorLoss(DDPGBaseActorLoss): bc_loss: torch.Tensor -class TD3PlusBCImpl(TD3Impl): - _alpha: float - - def __init__( - self, - observation_shape: Shape, - action_size: int, - modules: DDPGModules, - q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - gamma: float, - tau: float, - target_smoothing_sigma: float, - target_smoothing_clip: float, - alpha: float, - update_actor_interval: int, - compiled: bool, - device: str, - ): - super().__init__( - observation_shape=observation_shape, - action_size=action_size, - modules=modules, - q_func_forwarder=q_func_forwarder, - targ_q_func_forwarder=targ_q_func_forwarder, - gamma=gamma, - tau=tau, - target_smoothing_sigma=target_smoothing_sigma, - target_smoothing_clip=target_smoothing_clip, - update_actor_interval=update_actor_interval, - compiled=compiled, - device=device, - ) +class TD3PlusBCActorLossFn(DDPGBaseActorLossFn): + def __init__(self, q_func_forwarder: ContinuousEnsembleQFunctionForwarder, policy: Policy, alpha: float): + self._q_func_forwarder = q_func_forwarder + self._policy = policy self._alpha = alpha - def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput - ) -> TD3PlusBCActorLoss: + def __call__(self, batch: TorchMiniBatch) -> TD3PlusBCActorLoss: + action = self._policy(batch.observations) q_t = self._q_func_forwarder.compute_expected_q( batch.observations, action.squashed_mu, "none" )[0]