
Commit

Refactor DDPG, SAC and TD3
takuseno committed Nov 18, 2024
1 parent 9ca1912 commit 532d87c
Showing 8 changed files with 446 additions and 314 deletions.
43 changes: 35 additions & 8 deletions d3rlpy/algos/qlearning/ddpg.py
@@ -11,7 +11,8 @@
from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.ddpg_impl import DDPGImpl, DDPGModules
from .torch.ddpg_impl import DDPGActionSampler, DDPGValuePredictor, DDPGCriticLossFn, DDPGActorLossFn, DDPGUpdater, DDPGModules
from .functional import FunctionalQLearningAlgoImplBase

__all__ = ["DDPGConfig", "DDPG"]

@@ -93,7 +94,7 @@ def get_type() -> str:
return "ddpg"


class DDPG(QLearningAlgoBase[DDPGImpl, DDPGConfig]):
class DDPG(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, DDPGConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
@@ -150,15 +151,41 @@ def inner_create_impl(
critic_optim=critic_optim,
)

self._impl = DDPGImpl(
updater = DDPGUpdater(
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
policy=policy,
targ_policy=targ_policy,
critic_optim=critic_optim,
actor_optim=actor_optim,
critic_loss_fn=DDPGCriticLossFn(
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
targ_policy=targ_policy,
gamma=self._config.gamma,
),
actor_loss_fn=DDPGActorLossFn(
q_func_forwarder=q_func_forwarder,
policy=policy,
),
tau=self._config.tau,
compiled=self.compiled,
)
action_sampler = DDPGActionSampler(policy)
value_predictor = DDPGValuePredictor(q_func_forwarder)

self._impl = FunctionalQLearningAlgoImplBase(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
compiled=self.compiled,
updater=updater,
exploit_action_sampler=action_sampler,
explore_action_sampler=action_sampler,
value_predictor=value_predictor,
q_function=q_funcs,
q_function_optim=critic_optim.optim,
policy=policy,
policy_optim=actor_optim.optim,
device=self._device,
)
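
The hunk above replaces the monolithic DDPGImpl with a generic FunctionalQLearningAlgoImplBase assembled from small injected components: an updater that owns the training step, critic/actor loss objects, action samplers for exploitation and exploration, and a value predictor. The functional.py module added by this commit is not expanded on this page, so the sketch below only illustrates the composition pattern under that assumption; every name in it (FunctionalImpl, Updater, ActionSampler, ValuePredictor) is a placeholder, not the real d3rlpy API.

```python
from dataclasses import dataclass
from typing import Callable, Dict

import torch

# Illustrative component types:
# (batch, grad_step) -> metrics, obs -> action, (obs, action) -> Q-value.
Updater = Callable[[object, int], Dict[str, float]]
ActionSampler = Callable[[torch.Tensor], torch.Tensor]
ValuePredictor = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


@dataclass
class FunctionalImpl:
    """Delegates all algorithm-specific behavior to injected callables."""

    updater: Updater
    exploit_action_sampler: ActionSampler
    explore_action_sampler: ActionSampler
    value_predictor: ValuePredictor

    def update(self, batch: object, grad_step: int) -> Dict[str, float]:
        # Critic/actor losses, optimizer steps and target syncs all live in the updater.
        return self.updater(batch, grad_step)

    def predict_best_action(self, x: torch.Tensor) -> torch.Tensor:
        return self.exploit_action_sampler(x)

    def sample_action(self, x: torch.Tensor) -> torch.Tensor:
        return self.explore_action_sampler(x)

    def predict_value(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        return self.value_predictor(x, action)
```

With a shell like this, DDPG, SAC, TD3, and TD3+BC can share a single impl class and differ only in which loss functions and samplers they plug in, which is what the four rewritten inner_create_impl bodies in this commit do.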

50 changes: 42 additions & 8 deletions d3rlpy/algos/qlearning/sac.py
@@ -15,12 +15,17 @@
from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.ddpg_impl import DDPGActionSampler, DDPGValuePredictor
from .torch.sac_impl import (
DiscreteSACImpl,
DiscreteSACModules,
SACImpl,
SACModules,
SACActionSampler,
SACCriticLossFn,
SACActorLossFn,
SACUpdater,
)
from .functional import FunctionalQLearningAlgoImplBase

__all__ = ["SACConfig", "SAC", "DiscreteSACConfig", "DiscreteSAC"]

@@ -122,7 +127,7 @@ def get_type() -> str:
return "sac"


class SAC(QLearningAlgoBase[SACImpl, SACConfig]):
class SAC(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, SACConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
@@ -187,15 +192,44 @@ def inner_create_impl(
temp_optim=temp_optim,
)

self._impl = SACImpl(
updater = SACUpdater(
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
critic_optim=critic_optim,
actor_optim=actor_optim,
critic_loss_fn=SACCriticLossFn(
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
policy=policy,
log_temp=log_temp,
gamma=self._config.gamma,
),
actor_loss_fn=SACActorLossFn(
q_func_forwarder=q_func_forwarder,
policy=policy,
log_temp=log_temp,
temp_optim=temp_optim,
action_size=action_size,
),
tau=self._config.tau,
compiled=self.compiled,
)
exploit_action_sampler = DDPGActionSampler(policy)
explore_action_sampler = SACActionSampler(policy)
value_predictor = DDPGValuePredictor(q_func_forwarder)

self._impl = FunctionalQLearningAlgoImplBase(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
compiled=self.compiled,
updater=updater,
exploit_action_sampler=exploit_action_sampler,
explore_action_sampler=explore_action_sampler,
value_predictor=value_predictor,
q_function=q_funcs,
q_function_optim=critic_optim.optim,
policy=policy,
policy_optim=actor_optim.optim,
device=self._device,
)
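
SAC differs from DDPG mainly in its stochastic, entropy-regularized actor, which is why this hunk plugs in SACCriticLossFn/SACActorLossFn and pairs DDPGActionSampler (greedy, for exploitation) with SACActionSampler (stochastic, for exploration). The self-contained sketch below shows what an actor loss of that shape can look like for a tanh-squashed Gaussian policy; it is not the actual SACActorLossFn (whose temperature update via temp_optim lives in torch/sac_impl.py), and the assumption that the policy returns a mean/log-std pair is an illustration only.

```python
import torch
from torch import nn


class SoftActorLossFn:
    """SAC-style actor loss sketch: maximize Q minus an entropy penalty."""

    def __init__(self, policy: nn.Module, q_func: nn.Module, log_temp: torch.Tensor):
        self.policy = policy      # assumed to map obs -> (mean, log_std)
        self.q_func = q_func      # assumed to map (obs, action) -> Q-value
        self.log_temp = log_temp  # learnable scalar; exp() gives the temperature alpha

    def __call__(self, obs: torch.Tensor) -> torch.Tensor:
        mean, log_std = self.policy(obs)
        dist = torch.distributions.Normal(mean, log_std.exp())
        raw_action = dist.rsample()       # reparameterized sample keeps gradients
        action = torch.tanh(raw_action)   # squash into [-1, 1]
        # Log-probability with the tanh change-of-variables correction.
        log_prob = dist.log_prob(raw_action).sum(dim=-1)
        log_prob = log_prob - torch.log(1 - action.pow(2) + 1e-6).sum(dim=-1)
        q_value = self.q_func(obs, action).squeeze(-1)
        # The temperature has its own loss/optimizer, so treat it as a constant here.
        alpha = self.log_temp.exp().detach()
        return (alpha * log_prob - q_value).mean()
```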

51 changes: 39 additions & 12 deletions d3rlpy/algos/qlearning/td3.py
@@ -11,8 +11,9 @@
from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.ddpg_impl import DDPGModules
from .torch.td3_impl import TD3Impl
from .functional import FunctionalQLearningAlgoImplBase
from .torch.ddpg_impl import DDPGModules, DDPGActionSampler, DDPGValuePredictor, DDPGActorLossFn
from .torch.td3_impl import TD3CriticLossFn, TD3Updater

__all__ = ["TD3Config", "TD3"]

@@ -102,7 +103,7 @@ def get_type() -> str:
return "td3"


class TD3(QLearningAlgoBase[TD3Impl, TD3Config]):
class TD3(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, TD3Config]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
@@ -159,18 +160,44 @@ def inner_create_impl(
critic_optim=critic_optim,
)

self._impl = TD3Impl(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
updater = TD3Updater(
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
policy=policy,
targ_policy=targ_policy,
critic_optim=critic_optim,
actor_optim=actor_optim,
critic_loss_fn=TD3CriticLossFn(
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
targ_policy=targ_policy,
gamma=self._config.gamma,
target_smoothing_sigma=self._config.target_smoothing_sigma,
target_smoothing_clip=self._config.target_smoothing_clip,
),
actor_loss_fn=DDPGActorLossFn(
q_func_forwarder=q_func_forwarder,
policy=policy,
),
tau=self._config.tau,
target_smoothing_sigma=self._config.target_smoothing_sigma,
target_smoothing_clip=self._config.target_smoothing_clip,
update_actor_interval=self._config.update_actor_interval,
compiled=self.compiled,
)
action_sampler = DDPGActionSampler(policy)
value_predictor = DDPGValuePredictor(q_func_forwarder)

self._impl = FunctionalQLearningAlgoImplBase(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
updater=updater,
exploit_action_sampler=action_sampler,
explore_action_sampler=action_sampler,
value_predictor=value_predictor,
q_function=q_funcs,
q_function_optim=critic_optim.optim,
policy=policy,
policy_optim=actor_optim.optim,
device=self._device,
)
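
TD3 keeps DDPG's actor loss (the hunk reuses DDPGActorLossFn) and changes two things: the critic target is computed with target-policy smoothing (target_smoothing_sigma / target_smoothing_clip inside TD3CriticLossFn), and actor plus target-network updates are delayed by update_actor_interval inside TD3Updater. Below is a minimal sketch of those two mechanisms, independent of the real TD3CriticLossFn/TD3Updater code; the [-1, 1] action range is an assumption.

```python
from typing import Callable, Dict

import torch


def smoothed_target_action(
    target_action: torch.Tensor,
    sigma: float,
    clip: float,
) -> torch.Tensor:
    """TD3 target-policy smoothing: add clipped Gaussian noise to the target action."""
    noise = (torch.randn_like(target_action) * sigma).clamp(-clip, clip)
    return (target_action + noise).clamp(-1.0, 1.0)  # assumes actions live in [-1, 1]


def delayed_update_step(
    grad_step: int,
    update_actor_interval: int,
    update_critic: Callable[[], Dict[str, float]],
    update_actor: Callable[[], Dict[str, float]],
) -> Dict[str, float]:
    """TD3 delayed updates: the critic trains every step, the actor only periodically."""
    metrics = update_critic()
    if grad_step % update_actor_interval == 0:
        # Actor loss and target-network sync would run here.
        metrics.update(update_actor())
    return metrics
```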

54 changes: 41 additions & 13 deletions d3rlpy/algos/qlearning/td3_plus_bc.py
@@ -11,8 +11,10 @@
from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.ddpg_impl import DDPGModules
from .torch.td3_plus_bc_impl import TD3PlusBCImpl
from .functional import FunctionalQLearningAlgoImplBase
from .torch.ddpg_impl import DDPGModules, DDPGValuePredictor, DDPGActionSampler
from .torch.td3_impl import TD3CriticLossFn, TD3Updater
from .torch.td3_plus_bc_impl import TD3PlusBCActorLossFn

__all__ = ["TD3PlusBCConfig", "TD3PlusBC"]

@@ -94,7 +96,7 @@ def get_type() -> str:
return "td3_plus_bc"


class TD3PlusBC(QLearningAlgoBase[TD3PlusBCImpl, TD3PlusBCConfig]):
class TD3PlusBC(QLearningAlgoBase[FunctionalQLearningAlgoImplBase, TD3PlusBCConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
@@ -151,19 +153,45 @@ def inner_create_impl(
critic_optim=critic_optim,
)

self._impl = TD3PlusBCImpl(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
updater = TD3Updater(
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
policy=policy,
targ_policy=targ_policy,
critic_optim=critic_optim,
actor_optim=actor_optim,
critic_loss_fn=TD3CriticLossFn(
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
targ_policy=targ_policy,
gamma=self._config.gamma,
target_smoothing_sigma=self._config.target_smoothing_sigma,
target_smoothing_clip=self._config.target_smoothing_clip,
),
actor_loss_fn=TD3PlusBCActorLossFn(
q_func_forwarder=q_func_forwarder,
policy=policy,
alpha=self._config.alpha,
),
tau=self._config.tau,
target_smoothing_sigma=self._config.target_smoothing_sigma,
target_smoothing_clip=self._config.target_smoothing_clip,
alpha=self._config.alpha,
update_actor_interval=self._config.update_actor_interval,
compiled=self.compiled,
)
action_sampler = DDPGActionSampler(policy)
value_predictor = DDPGValuePredictor(q_func_forwarder)

self._impl = FunctionalQLearningAlgoImplBase(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
updater=updater,
exploit_action_sampler=action_sampler,
explore_action_sampler=action_sampler,
value_predictor=value_predictor,
q_function=q_funcs,
q_function_optim=critic_optim.optim,
policy=policy,
policy_optim=actor_optim.optim,
device=self._device,
)
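
TD3+BC reuses the whole TD3 training machinery (TD3CriticLossFn and TD3Updater) and only swaps in TD3PlusBCActorLossFn, which adds a behavior-cloning term weighted by alpha. Below is a rough sketch of that objective as defined in the TD3+BC paper (lambda = alpha / mean|Q|, plus an MSE term toward the dataset action); the real class in torch/td3_plus_bc_impl.py may differ in detail.

```python
import torch
import torch.nn.functional as F


def td3_plus_bc_actor_loss(
    q_value: torch.Tensor,         # Q(s, pi(s)) for the current policy action
    policy_action: torch.Tensor,   # pi(s)
    dataset_action: torch.Tensor,  # action stored in the offline dataset
    alpha: float,
) -> torch.Tensor:
    # Adaptive scale from the TD3+BC paper: lambda = alpha / mean(|Q|).
    lam = alpha / (q_value.abs().mean().detach() + 1e-8)
    bc_loss = F.mse_loss(policy_action, dataset_action)
    # Maximize the scaled Q-value while staying close to the dataset actions.
    return -lam * q_value.mean() + bc_loss
```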

[4 more changed files not expanded in this view]

