Skip to content

Commit

Permalink
add PRDC (#431)
Browse files Browse the repository at this point in the history
* add PRDC

* update README.md

* Remove useless imports

* Fix errors from lint
  • Loading branch information
liyc-ai authored Nov 8, 2024
1 parent e660d23 commit 8c68ea6
Show file tree
Hide file tree
Showing 10 changed files with 434 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ build
dist
/.idea/
*.egg-info
/.vscode/
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ $ docker run -it --gpus all --name d3rlpy takuseno/d3rlpy:latest bash
| [Critic Reguralized Regression (CRR)](https://arxiv.org/abs/2006.15134) | :no_entry: | :white_check_mark: |
| [Policy in Latent Action Space (PLAS)](https://arxiv.org/abs/2011.07213) | :no_entry: | :white_check_mark: |
| [TD3+BC](https://arxiv.org/abs/2106.06860) | :no_entry: | :white_check_mark: |
| [Policy Regularization with Dataset Constraint (PRDC)](https://arxiv.org/abs/2306.06569) | :no_entry: | :white_check_mark: |
| [Implicit Q-Learning (IQL)](https://arxiv.org/abs/2110.06169) | :no_entry: | :white_check_mark: |
| [Calibrated Q-Learning (Cal-QL)](https://arxiv.org/abs/2303.05479) | :no_entry: | :white_check_mark: |
| [ReBRAC](https://arxiv.org/abs/2305.09836) | :no_entry: | :white_check_mark: |
Expand Down
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .iql import *
from .nfq import *
from .plas import *
from .prdc import *
from .random_policy import *
from .rebrac import *
from .sac import *
Expand Down
259 changes: 259 additions & 0 deletions d3rlpy/algos/qlearning/prdc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
import dataclasses
from typing import Callable, Optional

import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from typing_extensions import Self

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace, LoggingStrategy
from ...dataset import ReplayBufferBase
from ...logging import FileAdapterFactory, LoggerAdapterFactory
from ...metrics import EvaluatorProtocol
from ...models.builders import (
create_continuous_q_function,
create_deterministic_policy,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from ..utility import build_scalers_with_transition_picker
from .base import QLearningAlgoBase
from .torch.ddpg_impl import DDPGModules
from .torch.prdc_impl import PRDCImpl

__all__ = ["PRDCConfig", "PRDC"]


@dataclasses.dataclass()
class PRDCConfig(LearnableConfig):
r"""Config of PRDC algorithm.
PRDC is an simple offline RL algorithm built on top of TD3.
PRDC introduces Dataset Constraint (DC)-reguralized policy objective function.
.. math::
J(\phi) = \mathbb{E}_{s \sim D}
[\lambda Q(s, \pi(s)) - d^\beta_D(s, \pi(s))]
where
.. math::
\lambda = \frac{\alpha}{\frac{1}{N} \sum_(s_i, a_i) |Q(s_i, a_i)|}
and `d^\beta_\mathcal{D}(s,\pi(s))` is the DC loss, defined as
.. math::
d^\beta_\mathcal{D}(s,\pi(s)) = \min_{\hat{s}, \hat{a} \sim D}
[\| (\beta s) \oplus \pi(s) - (\beta \hat{s}) \oplus \hat{a} \|]
References:
* `Ran et al., Policy Regularization with Dataset Constraint for Offline Reinforcement Learning
Learning. <https://arxiv.org/abs/2306.06569>`_
Args:
observation_scaler (d3rlpy.preprocessing.ObservationScaler):
Observation preprocessor.
action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
actor_learning_rate (float): Learning rate for a policy function.
critic_learning_rate (float): Learning rate for Q functions.
actor_optim_factory (d3rlpy.optimizers.OptimizerFactory):
Optimizer factory for the actor.
critic_optim_factory (d3rlpy.optimizers.OptimizerFactory):
Optimizer factory for the critic.
actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
Encoder factory for the actor.
critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
Encoder factory for the critic.
q_func_factory (d3rlpy.models.q_functions.QFunctionFactory):
Q function factory.
batch_size (int): Mini-batch size.
gamma (float): Discount factor.
tau (float): Target network synchronization coefficiency.
n_critics (int): Number of Q functions for ensemble.
target_smoothing_sigma (float): Standard deviation for target noise.
target_smoothing_clip (float): Clipping range for target noise.
alpha (float): :math:`\alpha` value.
beta (float): :math:`\beta` value.
update_actor_interval (int): Interval to update policy function
described as `delayed policy update` in the paper.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4
actor_optim_factory: OptimizerFactory = make_optimizer_field()
critic_optim_factory: OptimizerFactory = make_optimizer_field()
actor_encoder_factory: EncoderFactory = make_encoder_field()
critic_encoder_factory: EncoderFactory = make_encoder_field()
q_func_factory: QFunctionFactory = make_q_func_field()
batch_size: int = 256
gamma: float = 0.99
tau: float = 0.005
n_critics: int = 2
target_smoothing_sigma: float = 0.2
target_smoothing_clip: float = 0.5
alpha: float = 2.5
beta: float = 2.0
update_actor_interval: int = 2

def create(self, device: DeviceArg = False, enable_ddp: bool = False) -> "PRDC":
return PRDC(self, device, enable_ddp)

@staticmethod
def get_type() -> str:
return "prdc"


class PRDC(QLearningAlgoBase[PRDCImpl, PRDCConfig]):
_nbsr = NearestNeighbors(n_neighbors=1, algorithm="auto", n_jobs=-1)

def inner_create_impl(self, observation_shape: Shape, action_size: int) -> None:
policy = create_deterministic_policy(
observation_shape,
action_size,
self._config.actor_encoder_factory,
device=self._device,
enable_ddp=self._enable_ddp,
)
targ_policy = create_deterministic_policy(
observation_shape,
action_size,
self._config.actor_encoder_factory,
device=self._device,
enable_ddp=self._enable_ddp,
)
q_funcs, q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
enable_ddp=self._enable_ddp,
)
targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
enable_ddp=self._enable_ddp,
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(),
lr=self._config.actor_learning_rate,
compiled=self.compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=self.compiled,
)

modules = DDPGModules(
policy=policy,
targ_policy=targ_policy,
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
actor_optim=actor_optim,
critic_optim=critic_optim,
)

self._impl = PRDCImpl(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
target_smoothing_sigma=self._config.target_smoothing_sigma,
target_smoothing_clip=self._config.target_smoothing_clip,
alpha=self._config.alpha,
beta=self._config.beta,
update_actor_interval=self._config.update_actor_interval,
compiled=self.compiled,
nbsr=self._nbsr,
device=self._device,
)

def fit(
self,
dataset: ReplayBufferBase,
n_steps: int,
n_steps_per_epoch: int = 10000,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logging_steps: int = 500,
logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
save_interval: int = 1,
evaluators: Optional[dict[str, EvaluatorProtocol]] = None,
callback: Optional[Callable[[Self, int, int], None]] = None,
epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
) -> list[tuple[int, dict[str, float]]]:
observations = []
actions = []
for episode in dataset.buffer.episodes:
for i in range(episode.transition_count):
transition = dataset.transition_picker(episode, i)
observations.append(np.reshape(transition.observation, (1, -1)))
actions.append(np.reshape(transition.action, (1, -1)))
observations = np.concatenate(observations, axis=0)
actions = np.concatenate(actions, axis=0)

build_scalers_with_transition_picker(self, dataset)
if self.observation_scaler and self.observation_scaler.built:
observations = (
self.observation_scaler.transform(
torch.tensor(observations, device=self._device)
)
.cpu()
.numpy()
)

if self.action_scaler and self.action_scaler.built:
actions = (
self.action_scaler.transform(torch.tensor(actions, device=self._device))
.cpu()
.numpy()
)

self._nbsr.fit(
np.concatenate(
[np.multiply(observations, self._config.beta), actions],
axis=1,
)
)

return super().fit(
dataset=dataset,
n_steps=n_steps,
n_steps_per_epoch=n_steps_per_epoch,
logging_steps=logging_steps,
logging_strategy=logging_strategy,
experiment_name=experiment_name,
with_timestamp=with_timestamp,
logger_adapter=logger_adapter,
show_progress=show_progress,
save_interval=save_interval,
evaluators=evaluators,
callback=callback,
epoch_callback=epoch_callback,
)

def get_action_type(self) -> ActionSpace:
return ActionSpace.CONTINUOUS


register_learnable(PRDCConfig)
1 change: 1 addition & 0 deletions d3rlpy/algos/qlearning/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .dqn_impl import *
from .iql_impl import *
from .plas_impl import *
from .prdc_impl import *
from .rebrac_impl import *
from .sac_impl import *
from .td3_impl import *
Expand Down
84 changes: 84 additions & 0 deletions d3rlpy/algos/qlearning/torch/prdc_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# pylint: disable=too-many-ancestors
import dataclasses

import torch
from sklearn.neighbors import NearestNeighbors

from ....models.torch import ActionOutput, ContinuousEnsembleQFunctionForwarder
from ....torch_utility import TorchMiniBatch
from ....types import Shape
from .ddpg_impl import DDPGBaseActorLoss, DDPGModules
from .td3_impl import TD3Impl

__all__ = ["PRDCImpl"]


@dataclasses.dataclass(frozen=True)
class PRDCActorLoss(DDPGBaseActorLoss):
dc_loss: torch.Tensor


class PRDCImpl(TD3Impl):
_alpha: float
_beta: float
_nbsr: NearestNeighbors

def __init__(
self,
observation_shape: Shape,
action_size: int,
modules: DDPGModules,
q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
gamma: float,
tau: float,
target_smoothing_sigma: float,
target_smoothing_clip: float,
alpha: float,
beta: float,
update_actor_interval: int,
compiled: bool,
nbsr: NearestNeighbors,
device: str,
):
super().__init__(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=gamma,
tau=tau,
target_smoothing_sigma=target_smoothing_sigma,
target_smoothing_clip=target_smoothing_clip,
update_actor_interval=update_actor_interval,
compiled=compiled,
device=device,
)
self._alpha = alpha
self._beta = beta
self._nbsr = nbsr

def compute_actor_loss(
self, batch: TorchMiniBatch, action: ActionOutput
) -> PRDCActorLoss:
q_t = self._q_func_forwarder.compute_expected_q(
batch.observations, action.squashed_mu, "none"
)[0]
lam = self._alpha / (q_t.abs().mean()).detach()
key = (
torch.cat(
[torch.mul(batch.observations, self._beta), action.squashed_mu], dim=-1
)
.detach()
.cpu()
.numpy()
)
idx = self._nbsr.kneighbors(key, n_neighbors=1, return_distance=False)
nearest_neightbour = torch.tensor(
self._nbsr._fit_X[idx][:, :, -self.action_size :],
device=self.device,
dtype=action.squashed_mu.dtype,
).squeeze(dim=1)
dc_loss = torch.nn.functional.mse_loss(action.squashed_mu, nearest_neightbour)
return PRDCActorLoss(actor_loss=lam * -q_t.mean() + dc_loss, dc_loss=dc_loss)
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,6 @@ follow_imports_for_stubs = True
ignore_missing_imports = True
follow_imports = skip
follow_imports_for_stubs = True

[mypy-sklearn.*]
ignore_missing_imports = True
Loading

0 comments on commit 8c68ea6

Please sign in to comment.