Commit
* add PRDC
* update README.md
* Remove useless imports
* Fix errors from lint
Showing 10 changed files with 434 additions and 0 deletions.
@@ -21,3 +21,4 @@ build
 dist
 /.idea/
 *.egg-info
+/.vscode/
@@ -0,0 +1,259 @@
import dataclasses
from typing import Callable, Optional

import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from typing_extensions import Self

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace, LoggingStrategy
from ...dataset import ReplayBufferBase
from ...logging import FileAdapterFactory, LoggerAdapterFactory
from ...metrics import EvaluatorProtocol
from ...models.builders import (
    create_continuous_q_function,
    create_deterministic_policy,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...optimizers.optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from ..utility import build_scalers_with_transition_picker
from .base import QLearningAlgoBase
from .torch.ddpg_impl import DDPGModules
from .torch.prdc_impl import PRDCImpl

__all__ = ["PRDCConfig", "PRDC"]

@dataclasses.dataclass()
class PRDCConfig(LearnableConfig):
    r"""Config of PRDC algorithm.

    PRDC is a simple offline RL algorithm built on top of TD3.
    PRDC introduces a Dataset Constraint (DC)-regularized policy objective
    function.

    .. math::

        J(\phi) = \mathbb{E}_{s \sim \mathcal{D}}
            [\lambda Q(s, \pi(s)) - d^\beta_\mathcal{D}(s, \pi(s))]

    where

    .. math::

        \lambda = \frac{\alpha}{\frac{1}{N} \sum_{(s_i, a_i)} |Q(s_i, a_i)|}

    and :math:`d^\beta_\mathcal{D}(s, \pi(s))` is the DC loss, defined as

    .. math::

        d^\beta_\mathcal{D}(s, \pi(s)) = \min_{(\hat{s}, \hat{a}) \sim \mathcal{D}}
            [\| (\beta s) \oplus \pi(s) - (\beta \hat{s}) \oplus \hat{a} \|]

    References:
        * `Ran et al., Policy Regularization with Dataset Constraint for
          Offline Reinforcement Learning. <https://arxiv.org/abs/2306.06569>`_

    Args:
        observation_scaler (d3rlpy.preprocessing.ObservationScaler):
            Observation preprocessor.
        action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
        reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
        actor_learning_rate (float): Learning rate for a policy function.
        critic_learning_rate (float): Learning rate for Q functions.
        actor_optim_factory (d3rlpy.optimizers.OptimizerFactory):
            Optimizer factory for the actor.
        critic_optim_factory (d3rlpy.optimizers.OptimizerFactory):
            Optimizer factory for the critic.
        actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
            Encoder factory for the actor.
        critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
            Encoder factory for the critic.
        q_func_factory (d3rlpy.models.q_functions.QFunctionFactory):
            Q function factory.
        batch_size (int): Mini-batch size.
        gamma (float): Discount factor.
        tau (float): Target network synchronization coefficient.
        n_critics (int): Number of Q functions for ensemble.
        target_smoothing_sigma (float): Standard deviation for target noise.
        target_smoothing_clip (float): Clipping range for target noise.
        alpha (float): :math:`\alpha` value.
        beta (float): :math:`\beta` value.
        update_actor_interval (int): Interval to update policy function
            described as `delayed policy update` in the paper.
        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
    """

    actor_learning_rate: float = 3e-4
    critic_learning_rate: float = 3e-4
    actor_optim_factory: OptimizerFactory = make_optimizer_field()
    critic_optim_factory: OptimizerFactory = make_optimizer_field()
    actor_encoder_factory: EncoderFactory = make_encoder_field()
    critic_encoder_factory: EncoderFactory = make_encoder_field()
    q_func_factory: QFunctionFactory = make_q_func_field()
    batch_size: int = 256
    gamma: float = 0.99
    tau: float = 0.005
    n_critics: int = 2
    target_smoothing_sigma: float = 0.2
    target_smoothing_clip: float = 0.5
    alpha: float = 2.5
    beta: float = 2.0
    update_actor_interval: int = 2

    def create(self, device: DeviceArg = False, enable_ddp: bool = False) -> "PRDC":
        return PRDC(self, device, enable_ddp)

    @staticmethod
    def get_type() -> str:
        return "prdc"

class PRDC(QLearningAlgoBase[PRDCImpl, PRDCConfig]):
    # Nearest-neighbor index over (beta * s, a) pairs; fit() builds it and
    # PRDCImpl queries it in the DC loss.
    _nbsr = NearestNeighbors(n_neighbors=1, algorithm="auto", n_jobs=-1)

    def inner_create_impl(self, observation_shape: Shape, action_size: int) -> None:
        policy = create_deterministic_policy(
            observation_shape,
            action_size,
            self._config.actor_encoder_factory,
            device=self._device,
            enable_ddp=self._enable_ddp,
        )
        targ_policy = create_deterministic_policy(
            observation_shape,
            action_size,
            self._config.actor_encoder_factory,
            device=self._device,
            enable_ddp=self._enable_ddp,
        )
        q_funcs, q_func_forwarder = create_continuous_q_function(
            observation_shape,
            action_size,
            self._config.critic_encoder_factory,
            self._config.q_func_factory,
            n_ensembles=self._config.n_critics,
            device=self._device,
            enable_ddp=self._enable_ddp,
        )
        targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
            observation_shape,
            action_size,
            self._config.critic_encoder_factory,
            self._config.q_func_factory,
            n_ensembles=self._config.n_critics,
            device=self._device,
            enable_ddp=self._enable_ddp,
        )

        actor_optim = self._config.actor_optim_factory.create(
            policy.named_modules(),
            lr=self._config.actor_learning_rate,
            compiled=self.compiled,
        )
        critic_optim = self._config.critic_optim_factory.create(
            q_funcs.named_modules(),
            lr=self._config.critic_learning_rate,
            compiled=self.compiled,
        )

        modules = DDPGModules(
            policy=policy,
            targ_policy=targ_policy,
            q_funcs=q_funcs,
            targ_q_funcs=targ_q_funcs,
            actor_optim=actor_optim,
            critic_optim=critic_optim,
        )

        self._impl = PRDCImpl(
            observation_shape=observation_shape,
            action_size=action_size,
            modules=modules,
            q_func_forwarder=q_func_forwarder,
            targ_q_func_forwarder=targ_q_func_forwarder,
            gamma=self._config.gamma,
            tau=self._config.tau,
            target_smoothing_sigma=self._config.target_smoothing_sigma,
            target_smoothing_clip=self._config.target_smoothing_clip,
            alpha=self._config.alpha,
            beta=self._config.beta,
            update_actor_interval=self._config.update_actor_interval,
            compiled=self.compiled,
            nbsr=self._nbsr,
            device=self._device,
        )

    def fit(
        self,
        dataset: ReplayBufferBase,
        n_steps: int,
        n_steps_per_epoch: int = 10000,
        experiment_name: Optional[str] = None,
        with_timestamp: bool = True,
        logging_steps: int = 500,
        logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
        logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
        show_progress: bool = True,
        save_interval: int = 1,
        evaluators: Optional[dict[str, EvaluatorProtocol]] = None,
        callback: Optional[Callable[[Self, int, int], None]] = None,
        epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
    ) -> list[tuple[int, dict[str, float]]]:
        # Flatten the dataset into observation and action matrices.
        observations = []
        actions = []
        for episode in dataset.buffer.episodes:
            for i in range(episode.transition_count):
                transition = dataset.transition_picker(episode, i)
                observations.append(np.reshape(transition.observation, (1, -1)))
                actions.append(np.reshape(transition.action, (1, -1)))
        observations = np.concatenate(observations, axis=0)
        actions = np.concatenate(actions, axis=0)

        # Apply the same preprocessing that will be used during training.
        build_scalers_with_transition_picker(self, dataset)
        if self.observation_scaler and self.observation_scaler.built:
            observations = (
                self.observation_scaler.transform(
                    torch.tensor(observations, device=self._device)
                )
                .cpu()
                .numpy()
            )

        if self.action_scaler and self.action_scaler.built:
            actions = (
                self.action_scaler.transform(torch.tensor(actions, device=self._device))
                .cpu()
                .numpy()
            )

        # Build the nearest-neighbor index over (beta * s, a) pairs used by
        # the DC loss.
        self._nbsr.fit(
            np.concatenate(
                [np.multiply(observations, self._config.beta), actions],
                axis=1,
            )
        )

        return super().fit(
            dataset=dataset,
            n_steps=n_steps,
            n_steps_per_epoch=n_steps_per_epoch,
            logging_steps=logging_steps,
            logging_strategy=logging_strategy,
            experiment_name=experiment_name,
            with_timestamp=with_timestamp,
            logger_adapter=logger_adapter,
            show_progress=show_progress,
            save_interval=save_interval,
            evaluators=evaluators,
            callback=callback,
            epoch_callback=epoch_callback,
        )

    def get_action_type(self) -> ActionSpace:
        return ActionSpace.CONTINUOUS


register_learnable(PRDCConfig)
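Since the new config is registered with register_learnable, PRDC plugs into the standard d3rlpy workflow. Below is a minimal usage sketch, not part of the commit, assuming PRDCConfig is exported under d3rlpy.algos like the other Q-learning configs and that `dataset` is an already prepared d3rlpy replay buffer; the hyperparameter values simply restate the defaults above.

import d3rlpy

# Configure PRDC; alpha weights the Q-term and beta scales observations
# inside the nearest-neighbor key (defaults from PRDCConfig above).
prdc = d3rlpy.algos.PRDCConfig(
    alpha=2.5,
    beta=2.0,
    batch_size=256,
).create(device="cuda:0")

# The overridden fit() first builds the k-NN index over the dataset,
# then runs the usual offline training loop.
prdc.fit(
    dataset,
    n_steps=500000,
    n_steps_per_epoch=10000,
)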
@@ -0,0 +1,84 @@
# pylint: disable=too-many-ancestors
import dataclasses

import torch
from sklearn.neighbors import NearestNeighbors

from ....models.torch import ActionOutput, ContinuousEnsembleQFunctionForwarder
from ....torch_utility import TorchMiniBatch
from ....types import Shape
from .ddpg_impl import DDPGBaseActorLoss, DDPGModules
from .td3_impl import TD3Impl

__all__ = ["PRDCImpl"]


@dataclasses.dataclass(frozen=True)
class PRDCActorLoss(DDPGBaseActorLoss):
    dc_loss: torch.Tensor

class PRDCImpl(TD3Impl):
    _alpha: float
    _beta: float
    _nbsr: NearestNeighbors

    def __init__(
        self,
        observation_shape: Shape,
        action_size: int,
        modules: DDPGModules,
        q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
        targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
        gamma: float,
        tau: float,
        target_smoothing_sigma: float,
        target_smoothing_clip: float,
        alpha: float,
        beta: float,
        update_actor_interval: int,
        compiled: bool,
        nbsr: NearestNeighbors,
        device: str,
    ):
        super().__init__(
            observation_shape=observation_shape,
            action_size=action_size,
            modules=modules,
            q_func_forwarder=q_func_forwarder,
            targ_q_func_forwarder=targ_q_func_forwarder,
            gamma=gamma,
            tau=tau,
            target_smoothing_sigma=target_smoothing_sigma,
            target_smoothing_clip=target_smoothing_clip,
            update_actor_interval=update_actor_interval,
            compiled=compiled,
            device=device,
        )
        self._alpha = alpha
        self._beta = beta
        self._nbsr = nbsr

    def compute_actor_loss(
        self, batch: TorchMiniBatch, action: ActionOutput
    ) -> PRDCActorLoss:
        q_t = self._q_func_forwarder.compute_expected_q(
            batch.observations, action.squashed_mu, "none"
        )[0]
        # Scale the Q-term by lambda = alpha / mean(|Q(s, a)|).
        lam = self._alpha / (q_t.abs().mean()).detach()
        # Query key (beta * s, pi(s)) for the nearest-neighbor search.
        key = (
            torch.cat(
                [torch.mul(batch.observations, self._beta), action.squashed_mu], dim=-1
            )
            .detach()
            .cpu()
            .numpy()
        )
        idx = self._nbsr.kneighbors(key, n_neighbors=1, return_distance=False)
        # Take the action part of the nearest dataset pair from the fitted
        # sklearn index (relies on the private _fit_X attribute).
        nearest_neighbor = torch.tensor(
            self._nbsr._fit_X[idx][:, :, -self.action_size :],
            device=self.device,
            dtype=action.squashed_mu.dtype,
        ).squeeze(dim=1)
        dc_loss = torch.nn.functional.mse_loss(action.squashed_mu, nearest_neighbor)
        return PRDCActorLoss(actor_loss=lam * -q_t.mean() + dc_loss, dc_loss=dc_loss)
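To make the dataset-constraint lookup concrete, the following self-contained sketch reproduces the same key construction, nearest-neighbor query, and DC loss on toy numpy data; the demo_* names and shapes are illustrative and not part of the commit.

import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors

beta = 2.0
obs_dim, action_size, n = 3, 2, 100

# Toy dataset of (observation, action) pairs standing in for the replay buffer.
demo_observations = np.random.randn(n, obs_dim).astype(np.float32)
demo_actions = np.random.uniform(-1, 1, (n, action_size)).astype(np.float32)

# Index over concatenated (beta * s, a) keys, as built in PRDC.fit().
nbsr = NearestNeighbors(n_neighbors=1, algorithm="auto", n_jobs=-1)
nbsr.fit(np.concatenate([beta * demo_observations, demo_actions], axis=1))

# Query with (beta * s, pi(s)) for a batch of policy actions, as in
# compute_actor_loss(), and take the action part of the nearest dataset pair.
batch_obs = demo_observations[:4]
policy_actions = np.tanh(np.random.randn(4, action_size)).astype(np.float32)
key = np.concatenate([beta * batch_obs, policy_actions], axis=1)
idx = nbsr.kneighbors(key, n_neighbors=1, return_distance=False)
nearest_actions = demo_actions[idx[:, 0]]

# DC loss: mean squared distance between the policy action and the nearest
# dataset action.
dc_loss = torch.nn.functional.mse_loss(
    torch.from_numpy(policy_actions), torch.from_numpy(nearest_actions)
)
print(float(dc_loss))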