diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py index 0b944852d8..0d7694e737 100644 --- a/ding/policy/base_policy.py +++ b/ding/policy/base_policy.py @@ -318,8 +318,8 @@ def collect_mode(self) -> 'Policy.collect_function': # noqa to define immutable interfaces and restrict the usage of policy in different mode. Moreover, derived \ subclass can override the interfaces to customize its own collect mode. Returns: - - interfaces (:obj:`Policy.collect_function`): The interfaces of collect mode of policy, it is a namedtuple \ - whose values of distinct fields are different internal methods. + - interfaces (:obj:`Policy.collect_function`): The interfaces of collect mode of policy, it is a \ + namedtuple whose values of distinct fields are different internal methods. Examples: >>> policy = Policy(cfg, model) >>> policy_collect = policy.collect_mode diff --git a/ding/policy/common_utils.py b/ding/policy/common_utils.py index d04abf3658..fd2c7d3d61 100644 --- a/ding/policy/common_utils.py +++ b/ding/policy/common_utils.py @@ -1,4 +1,4 @@ -from typing import List, Any, Dict +from typing import List, Any, Dict, Callable import torch import numpy as np import treetensor.torch as ttorch @@ -56,7 +56,7 @@ def default_preprocess_learn( else: data['weight'] = data.get('weight', None) if use_nstep: - # Reward reshaping for n-step + # reward reshaping for n-step reward = data['reward'] if len(reward.shape) == 1: reward = reward.unsqueeze(1) @@ -69,10 +69,22 @@ def default_preprocess_learn( return data -def single_env_forward_wrapper(forward_fn): +def single_env_forward_wrapper(forward_fn: Callable) -> Callable: """ Overview: - Wrap policy to support gym-style interaction between policy and environment. + Wrap policy to support gym-style interaction between policy and single environment. + Arguments: + - forward_fn (:obj:`Callable`): The original forward function of policy. + Returns: + - wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy. + Examples: + >>> env = gym.make('CartPole-v0') + >>> policy = DQNPolicy(...) + >>> forward_fn = single_env_forward_wrapper(policy.eval_mode.forward) + >>> obs = env.reset() + >>> action = forward_fn(obs) + >>> next_obs, rew, done, info = env.step(action) + """ def _forward(obs): @@ -84,10 +96,23 @@ def _forward(obs): return _forward -def single_env_forward_wrapper_ttorch(forward_fn, cuda=True): +def single_env_forward_wrapper_ttorch(forward_fn: Callable, cuda: bool = True) -> Callable: """ Overview: - Wrap policy to support gym-style interaction between policy and environment for treetensor (ttorch) data. + Wrap policy to support gym-style interaction between policy and single environment for treetensor (ttorch) data. + Arguments: + - forward_fn (:obj:`Callable`): The original forward function of policy. + - cuda (:obj:`bool`): Whether to use cuda in policy, if True, this function will move the input data to cuda. + Returns: + - wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy. + + Examples: + >>> env = gym.make('CartPole-v0') + >>> policy = PPOFPolicy(...) + >>> forward_fn = single_env_forward_wrapper_ttorch(policy.eval) + >>> obs = env.reset() + >>> action = forward_fn(obs) + >>> next_obs, rew, done, info = env.step(action) """ def _forward(obs): diff --git a/ding/policy/ddpg.py b/ding/policy/ddpg.py index e68dec62b1..2ddb2c4af5 100644 --- a/ding/policy/ddpg.py +++ b/ding/policy/ddpg.py @@ -64,8 +64,7 @@ class DDPGPolicy(Policy): type='ddpg', # (bool) Whether to use cuda in policy. 
cuda=False, - # (bool) Whether learning policy is the same as collecting data policy(on-policy). - # Default False in DDPG. + # (bool) Whether learning policy is the same as collecting data policy(on-policy). Default False in DDPG. on_policy=False, # (bool) Whether to enable priority experience sample. priority=False, @@ -84,7 +83,7 @@ class DDPGPolicy(Policy): multi_agent=False, # learn_mode config learn=dict( - # How many updates(iterations) to train after collector's one collection. + # (int) How many updates(iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. # collect data -> update policy-> collect data -> ... update_per_collect=1, @@ -150,7 +149,7 @@ def default_model(self) -> Tuple[str, List[str]]: return 'continuous_qac', ['ding.model.template.qac'] def _init_learn(self) -> None: - r""" + """ Overview: Learn mode init method. Called by ``self.__init__``. Init actor and critic optimizers, algorithm config, main and target models. @@ -202,7 +201,7 @@ def _init_learn(self) -> None: self._forward_learn_cnt = 0 # count iterations def _forward_learn(self, data: dict) -> Dict[str, Any]: - r""" + """ Overview: Forward and backward function of learn mode. Arguments: @@ -343,7 +342,7 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: self._optimizer_critic.load_state_dict(state_dict['optimizer_critic']) def _init_collect(self) -> None: - r""" + """ Overview: Collect mode init method. Called by ``self.__init__``. Init traj and unroll length, collect model. @@ -365,7 +364,7 @@ def _init_collect(self) -> None: self._collect_model.reset() def _forward_collect(self, data: dict, **kwargs) -> dict: - r""" + """ Overview: Forward function of collect mode. Arguments: @@ -431,7 +430,7 @@ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, return get_train_sample(transitions, self._unroll_len) def _init_eval(self) -> None: - r""" + """ Overview: Evaluate mode init method. Called by ``self.__init__``. Init eval model. Unlike learn and collect model, eval model does not need noise. diff --git a/ding/policy/policy_factory.py b/ding/policy/policy_factory.py index d5e0a5da98..fa682b9c3c 100644 --- a/ding/policy/policy_factory.py +++ b/ding/policy/policy_factory.py @@ -8,17 +8,32 @@ class PolicyFactory: - r""" + """ Overview: - Pure random policy. Only used for initial sample collecting if `cfg.policy.random_collect_size` > 0. + Policy factory class, used to generate different policies for general purpose. Such as random action policy, \ + which is used for initial sample collecting for better exploration when ``random_collect_size`` > 0. + Interfaces: + ``get_random_policy`` """ @staticmethod def get_random_policy( - policy: 'BasePolicy', # noqa + policy: 'Policy.collect_mode', # noqa action_space: 'gym.spaces.Space' = None, # noqa forward_fn: Callable = None, - ) -> None: + ) -> 'Policy.collect_mode': # noqa + """ + Overview: + According to the given action space, define the forward function of the random policy, then pack it with \ + other interfaces of the given policy, and return the final collect mode interfaces of policy. + Arguments: + - policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy. + - action_space (:obj:`gym.spaces.Space`): The action space of the environment, gym-style. 
+ - forward_fn (:obj:`Callable`): It action space is too complex, you can define your own forward function \ + and pass it to this function, note you should set ``action_space`` to ``None`` in this case. + Returns: + - random_policy (:obj:`Policy.collect_mode`): The collect mode intefaces of the random policy. + """ assert not (action_space is None and forward_fn is None) random_collect_function = namedtuple( 'random_collect_function', [ @@ -69,7 +84,23 @@ def reset(*args, **kwargs) -> None: ) -def get_random_policy(cfg: EasyDict, policy: 'Policy.collect_mode', env: 'BaseEnvManager'): # noqa +def get_random_policy( + cfg: EasyDict, + policy: 'Policy.collect_mode', # noqa + env: 'BaseEnvManager' # noqa +) -> 'Policy.collect_mode': # noqa + """ + Overview: + The entry function to get the corresponding random policy. If a policy needs special data items in a \ + transition, then return itself, otherwise, we will use ``PolicyFactory`` to return a general random policy. + Arguments: + - cfg (:obj:`EasyDict`): The EasyDict-type dict configuration. + - policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy. + - env (:obj:`BaseEnvManager`): The env manager instance, which is used to get the action space for random \ + action generation. + Returns: + - random_policy (:obj:`Policy.collect_mode`): The collect mode intefaces of the random policy. + """ if cfg.policy.get('transition_with_policy_data', False): return policy else: diff --git a/ding/policy/sac.py b/ding/policy/sac.py index ef633f2340..dd3c64bacd 100644 --- a/ding/policy/sac.py +++ b/ding/policy/sac.py @@ -18,47 +18,10 @@ @POLICY_REGISTRY.register('discrete_sac') class DiscreteSACPolicy(Policy): - r""" - Overview: - Policy class of discrete SAC algorithm. Paper link: https://arxiv.org/pdf/1910.07207.pdf. - - Config: - == ==================== ======== ============= ================================= ======================= - ID Symbol Type Default Value Description Other - == ==================== ======== ============= ================================= ======================= - 1 ``type`` str discrete_sac | RL policy register name, refer | this arg is optional, - | to registry ``POLICY_REGISTRY`` | a placeholder - 2 ``cuda`` bool True | Whether to use cuda for network | - 3 ``on_policy`` bool False | DiscreteSAC is an off-policy | - | algorithm. | - 4 ``priority`` bool False | Whether to use priority | - | sampling in buffer. | - 5 | ``priority_IS_`` bool False | Whether use Importance Sampling | - | ``weight`` | weight to correct biased update | - 6 | ``random_`` int 10000 | Number of randomly collected | Default to 10000 for - | ``collect_size`` | training samples in replay | SAC, 25000 for DDPG/ - | | buffer when training starts. | TD3. - 7 | ``learn.learning`` float 3e-4 | Learning rate for soft q | Defalut to 1e-3 - | ``_rate_q`` | network. | - 8 | ``learn.learning`` float 3e-4 | Learning rate for policy | Defalut to 1e-3 - | ``_rate_policy`` | network. | - 9 | ``learn.alpha`` float 0.2 | Entropy regularization | alpha is initiali- - | | coefficient. | zation for auto - | | | `\alpha`, when - | | | auto_alpha is True - 10 | ``learn.`` bool False | Determine whether to use | Temperature parameter - | ``auto_alpha`` | auto temperature parameter | determines the - | | `\alpha`. | relative importance - | | | of the entropy term - | | | against the reward. - 11 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only - | ``ignore_done`` | done flag. 
| in env like Pendulum - 12 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation - | ``target_theta`` | target network. | factor in polyak aver - | | | aging for target - | | | networks. - == ==================== ======== ============= ================================= ======================= - """ + """ + Overview: + Policy class of discrete SAC algorithm. Paper link: https://arxiv.org/abs/1910.07207. + """ config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). @@ -82,6 +45,7 @@ class DiscreteSACPolicy(Policy): # For more details, please refer to TD3 about Clipped Double-Q Learning trick. twin_critic=True, ), + # learn_mode config learn=dict( # (int) How many updates (iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. @@ -123,8 +87,10 @@ class DiscreteSACPolicy(Policy): # (float) Weight uniform initialization max range in the last output layer init_w=3e-3, ), + # collect_mode config collect=dict( # (int) How many training samples collected in one collection procedure. + # Only one of [n_sample, n_episode] shoule be set. n_sample=1, # (int) Split episodes or trajectories into pieces with length `unroll_len`. unroll_len=1, @@ -132,6 +98,7 @@ class DiscreteSACPolicy(Policy): # In some algorithm like guided cost learning, we need to use logit to train the reward model. collector_logit=False, ), + eval=dict(), # for compability other=dict( replay_buffer=dict( # (int) Maximum size of replay buffer. Usually, larger buffer size is good @@ -142,6 +109,13 @@ class DiscreteSACPolicy(Policy): ) def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + """ if self._cfg.multi_agent: return 'discrete_maqac', ['ding.model.template.maqac'] else: @@ -206,6 +180,10 @@ def _init_learn(self) -> None: self._target_model.reset() def _forward_learn(self, data: dict) -> Dict[str, Any]: + """ + Overview: + Forward function of learn mode. + """ loss_dict = {} data = default_preprocess_learn( data, @@ -332,8 +310,15 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: } def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model, target_model and optimizers. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. + """ ret = { 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), 'optimizer_q': self._optimizer_q.state_dict(), 'optimizer_policy': self._optimizer_policy.state_dict(), } @@ -342,13 +327,29 @@ def _state_dict_learn(self) -> Dict[str, Any]: return ret def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + + .. tip:: + If you want to only load some parts of model, you can simply set the ``strict`` argument in \ + load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ + complicated operation. 
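+
+        Examples:
+            >>> # An illustrative sketch only: the checkpoint path and the externally created ``policy``
+            >>> # object are hypothetical and depend on your own training pipeline.
+            >>> state_dict = torch.load('ckpt_best.pth.tar', map_location='cpu')
+            >>> policy._load_state_dict_learn(state_dict)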
+        """
         self._learn_model.load_state_dict(state_dict['model'])
+        self._target_model.load_state_dict(state_dict['target_model'])
         self._optimizer_q.load_state_dict(state_dict['optimizer_q'])
         self._optimizer_policy.load_state_dict(state_dict['optimizer_policy'])
         if self._auto_alpha:
             self._alpha_optim.load_state_dict(state_dict['optimizer_alpha'])
 
     def _init_collect(self) -> None:
+        """
+        Overview:
+            Initialize the collect_mode of policy, mainly including collect_model.
+        """
         self._unroll_len = self._cfg.collect.unroll_len
         # Empirically, we found that eps_greedy_multinomial_sample works better than multinomial_sample
         # and eps_greedy_sample, and we don't divide logit by alpha,
@@ -357,6 +358,10 @@ def _init_collect(self) -> None:
         self._collect_model.reset()
 
     def _forward_collect(self, data: dict, eps: float) -> dict:
+        """
+        Overview:
+            Forward function for collect_mode.
+        """
         data_id = list(data.keys())
         data = default_collate(list(data.values()))
         if self._cuda:
@@ -369,25 +374,59 @@ def _forward_collect(self, data: dict, eps: float) -> dict:
         output = default_decollate(output)
         return {i: d for i, d in zip(data_id, output)}
 
-    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+    def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
+                            timestep: namedtuple) -> Dict[str, torch.Tensor]:
+        """
+        Overview:
+            Process and pack one timestep transition data into a dict, which can be directly used for training and \
+            saved in replay buffer. For discrete SAC, it contains obs, next_obs, logit, action, reward, done.
+        Arguments:
+            - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
+            - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
+                as input. For discrete SAC, it contains the action and the logit of the action.
+            - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
+                except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
+                reward, done, info, etc.
+        Returns:
+            - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
+        """
         transition = {
             'obs': obs,
             'next_obs': timestep.obs,
-            'action': model_output['action'],
-            'logit': model_output['logit'],
+            'action': policy_output['action'],
+            'logit': policy_output['logit'],
             'reward': timestep.reward,
             'done': timestep.done,
         }
         return transition
 
-    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
-        return get_train_sample(data, self._unroll_len)
+    def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """
+        Overview:
+            For a given trajectory (transitions, a list of transitions) data, process it into a list of samples that \
+            can be used for training directly. In discrete SAC, a train sample is a processed transition (unroll_len=1).
+        Arguments:
+            - transitions (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions), each element is \
+                in the same format as the return value of the ``self._process_transition`` method.
+        Returns:
+            - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is in a similar format \
+                as the input transitions, but may contain more data for training.
+        """
+        return get_train_sample(transitions, self._unroll_len)
 
     def _init_eval(self) -> None:
+        """
+        Overview:
+            Initialize the eval_mode of policy, mainly including eval_model.
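+            Unlike the collect model, the eval model is wrapped with ``argmax_sample``, so the action is \
+            selected greedily from the policy logit during evaluation.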
+ """ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample') self._eval_model.reset() def _forward_eval(self, data: dict) -> dict: + """ + Overview: + Forward function for eval_mode. + """ data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: @@ -401,6 +440,13 @@ def _forward_eval(self, data: dict) -> dict: return {i: d for i, d in zip(data_id, output)} def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ twin_critic = ['twin_critic_loss'] if self._twin_critic else [] if self._auto_alpha: return super()._monitor_vars_learn() + [ @@ -416,11 +462,11 @@ def _monitor_vars_learn(self) -> List[str]: @POLICY_REGISTRY.register('sac') class SACPolicy(Policy): - r""" - Overview: - Policy class of continuous SAC algorithm. Paper link: https://arxiv.org/pdf/1801.01290.pdf + """ + Overview: + Policy class of continuous SAC algorithm. Paper link: https://arxiv.org/pdf/1801.01290.pdf - Config: + Config: == ==================== ======== ============= ================================= ======================= ID Symbol Type Default Value Description Other == ==================== ======== ============= ================================= ======================= @@ -442,11 +488,11 @@ class SACPolicy(Policy): | ``_rate_policy`` | network. | 9 | ``learn.alpha`` float 0.2 | Entropy regularization | alpha is initiali- | | coefficient. | zation for auto - | | | `\alpha`, when + | | | alpha, when | | | auto_alpha is True 10 | ``learn.`` bool False | Determine whether to use | Temperature parameter | ``auto_alpha`` | auto temperature parameter | determines the - | | `\alpha`. | relative importance + | | alpha. | relative importance | | | of the entropy term | | | against the reward. 11 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only @@ -456,7 +502,7 @@ class SACPolicy(Policy): | | | aging for target | | | networks. == ==================== ======== ============= ================================= ======================= - """ + """ config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). @@ -482,6 +528,7 @@ class SACPolicy(Policy): # (str) Use reparameterization trick for continous action. action_space='reparameterization', ), + # learn_mode config learn=dict( # (int) How many updates (iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. @@ -523,6 +570,7 @@ class SACPolicy(Policy): # (float) Weight uniform initialization max range in the last output layer. init_w=3e-3, ), + # collect_mode config collect=dict( # (int) How many training samples collected in one collection procedure. n_sample=1, @@ -532,6 +580,7 @@ class SACPolicy(Policy): # In some algorithm like guided cost learning, we need to use logit to train the reward model. collector_logit=False, ), + eval=dict(), # for compability other=dict( replay_buffer=dict( # (int) Maximum size of replay buffer. Usually, larger buffer size is good @@ -542,12 +591,23 @@ class SACPolicy(Policy): ) def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. 
``__init__`` method will \
+            automatically call this method to get the default model setting and create model.
+        Returns:
+            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+        """
         if self._cfg.multi_agent:
             return 'continuous_maqac', ['ding.model.template.maqac']
         else:
             return 'continuous_qac', ['ding.model.template.qac']
 
     def _init_learn(self) -> None:
+        """
+        Overview:
+            Initialize the learn model and algorithm-related objects.
+        """
         self._priority = self._cfg.priority
         self._priority_IS_weight = self._cfg.priority_IS_weight
         self._twin_critic = self._cfg.model.twin_critic
@@ -608,6 +668,10 @@ def _init_learn(self) -> None:
         self._target_model.reset()
 
     def _forward_learn(self, data: dict) -> Dict[str, Any]:
+        """
+        Overview:
+            Forward function of learn mode.
+        """
         loss_dict = {}
         data = default_preprocess_learn(
             data,
@@ -730,6 +794,12 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
         }
 
     def _state_dict_learn(self) -> Dict[str, Any]:
+        """
+        Overview:
+            Return the state_dict of learn mode, usually including model, target_model and optimizers.
+        Returns:
+            - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
+        """
         ret = {
             'model': self._learn_model.state_dict(),
             'target_model': self._target_model.state_dict(),
@@ -741,6 +811,17 @@ def _state_dict_learn(self) -> Dict[str, Any]:
         return ret
 
     def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+        """
+        Overview:
+            Load the state_dict variable into policy learn mode.
+        Arguments:
+            - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
+
+        .. tip::
+            If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+            complicated operation.
+        """
         self._learn_model.load_state_dict(state_dict['model'])
         self._target_model.load_state_dict(state_dict['target_model'])
         self._optimizer_q.load_state_dict(state_dict['optimizer_q'])
@@ -749,11 +830,19 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
         self._alpha_optim.load_state_dict(state_dict['optimizer_alpha'])
 
     def _init_collect(self) -> None:
+        """
+        Overview:
+            Initialize the collect mode.
+        """
         self._unroll_len = self._cfg.collect.unroll_len
         self._collect_model = model_wrap(self._model, wrapper_name='base')
         self._collect_model.reset()
 
     def _forward_collect(self, data: dict) -> dict:
+        """
+        Overview:
+            Forward function of collect mode.
+        """
         data_id = list(data.keys())
         data = default_collate(list(data.values()))
         if self._cuda:
@@ -769,7 +858,23 @@ def _forward_collect(self, data: dict) -> dict:
         output = default_decollate(output)
         return {i: d for i, d in zip(data_id, output)}
 
-    def _process_transition(self, obs: Any, policy_output: dict, timestep: namedtuple) -> dict:
+    def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
+                            timestep: namedtuple) -> Dict[str, torch.Tensor]:
+        """
+        Overview:
+            Process and pack one timestep transition data into a dict, which can be directly used for training and \
+            saved in replay buffer. For continuous SAC, it contains obs, next_obs, action, reward, done. The logit \
+            will also be added when ``collector_logit`` is True.
+        Arguments:
+            - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ + as input. For continuous SAC, it contains the action and the logit (mu and sigma) of the action. + - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ + except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ + reward, done, info, etc. + Returns: + - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. + """ if self._cfg.collect.collector_logit: transition = { 'obs': obs, @@ -789,14 +894,34 @@ def _process_transition(self, obs: Any, policy_output: dict, timestep: namedtupl } return transition - def _get_train_sample(self, data: list) -> Union[None, List[Any]]: - return get_train_sample(data, self._unroll_len) + def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Overview: + For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ + can be used for training directly. In continuous SAC, a train sample is a processed transition \ + (unroll_len=1). + Arguments: + - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ + the same format as the return value of ``self._process_transition`` method. + Returns: + - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \ + as input transitions, but may contain more data for training. + """ + return get_train_sample(transitions, self._unroll_len) def _init_eval(self) -> None: + """ + Overview: + Initialize the eval mode. + """ self._eval_model = model_wrap(self._model, wrapper_name='base') self._eval_model.reset() def _forward_eval(self, data: dict) -> dict: + """ + Overview: + Forward function of eval mode. + """ data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: @@ -812,6 +937,13 @@ def _forward_eval(self, data: dict) -> dict: return {i: d for i, d in zip(data_id, output)} def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ twin_critic = ['twin_critic_loss'] if self._twin_critic else [] alpha_loss = ['alpha_loss'] if self._auto_alpha else [] return [ @@ -830,8 +962,18 @@ def _monitor_vars_learn(self) -> List[str]: @POLICY_REGISTRY.register('sqil_sac') class SQILSACPolicy(SACPolicy): + """ + Overview: + Policy class of continuous SAC algorithm with SQIL extension. + SAC paper link: https://arxiv.org/pdf/1801.01290.pdf + SQIL paper link: https://arxiv.org/abs/1905.11108 + """ def _init_learn(self) -> None: + """ + Overview: + Initialize the learn mode. + """ self._priority = self._cfg.priority self._priority_IS_weight = self._cfg.priority_IS_weight self._twin_critic = self._cfg.model.twin_critic @@ -896,6 +1038,10 @@ def _init_learn(self) -> None: self._monitor_entropy = True def _forward_learn(self, data: dict) -> Dict[str, Any]: + """ + Overview: + Forward function of learn mode. 
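+            For SQIL, the training data is composed of both agent and expert transitions. Besides the normal SAC \
+            losses, this method also records some auxiliary statistics, such as the cosine similarity diagnostic \
+            controlled by ``self._monitor_cos`` and the policy entropy controlled by ``self._monitor_entropy``.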
+ """ loss_dict = {} if self._monitor_cos: agent_data = default_preprocess_learn( @@ -1094,6 +1240,13 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: return var_monitor def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ twin_critic = ['twin_critic_loss'] if self._twin_critic else [] alpha_loss = ['alpha_loss'] if self._auto_alpha else [] cos_similarity = ['cos_similarity'] if self._monitor_cos else [] diff --git a/ding/policy/td3.py b/ding/policy/td3.py index 632e4b3c22..7359190282 100644 --- a/ding/policy/td3.py +++ b/ding/policy/td3.py @@ -5,17 +5,11 @@ @POLICY_REGISTRY.register('td3') class TD3Policy(DDPGPolicy): - r""" + """ Overview: - Policy class of TD3 algorithm. - - Since DDPG and TD3 share many common things, we can easily derive this TD3 + Policy class of TD3 algorithm. Since DDPG and TD3 share many common things, we can easily derive this TD3 \ class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and noise in model wrapper. - - https://arxiv.org/pdf/1802.09477.pdf - - Property: - learn_mode, collect_mode, eval_mode + Paper link: https://arxiv.org/pdf/1802.09477.pdf Config: @@ -68,9 +62,7 @@ class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and n type='td3', # (bool) Whether to use cuda for network. cuda=False, - # (bool type) on_policy: Determine whether on-policy or off-policy. - # on-policy setting influences the behaviour of buffer. - # Default False in TD3. + # (bool) on_policy: Determine whether on-policy or off-policy. Default False in TD3. on_policy=False, # (bool) Whether use priority(priority sample, IS weight, update priority) # Default False in TD3. @@ -80,6 +72,8 @@ class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and n # (int) Number of training samples(randomly collected) in replay buffer when training starts. # Default 25000 in DDPG/TD3. random_collect_size=25000, + # (bool) Whether to need policy data in process transition. + transition_with_policy_data=False, # (str) Action space type action_space='continuous', # ['continuous', 'hybrid'] # (bool) Whether use batch normalization for reward @@ -92,9 +86,9 @@ class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and n # Default True for TD3, False for DDPG. twin_critic=True, ), + # learn_mode config learn=dict( - - # How many updates(iterations) to train after collector's one collection. + # (int) How many updates(iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. # collect data -> update policy-> collect data -> ... update_per_collect=1, @@ -112,7 +106,7 @@ class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and n # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``), # when the episode step is greater than max episode step. ignore_done=False, - # (float type) target_theta: Used for soft update of the target network, + # (float) target_theta: Used for soft update of the target network, # aka. Interpolation factor in polyak averaging for target networks. # Default to 0.005. 
             target_theta=0.005,
@@ -130,30 +124,37 @@ class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and n
             noise_sigma=0.2,
             # (dict) Limit for range of target policy smoothing noise, aka. noise_clip.
             noise_range=dict(
+                # (float) min value of noise
                 min=-0.5,
+                # (float) max value of noise
                 max=0.5,
             ),
         ),
+        # collect_mode config
        collect=dict(
+            # (int) How many training samples collected in one collection procedure.
+            # Only one of [n_sample, n_episode] should be set.
             # n_sample=1,
             # (int) Cut trajectories into pieces with length "unroll_len".
             unroll_len=1,
             # (float) It is a must to add noise during collection. So here omits "noise" and only set "noise_sigma".
             noise_sigma=0.1,
         ),
-        eval=dict(
-            evaluator=dict(
-                # (int) Evaluate every "eval_freq" training iterations.
-                eval_freq=5000,
-            ),
-        ),
+        eval=dict(),  # for compatibility
         other=dict(
             replay_buffer=dict(
-                # (int) Maximum size of replay buffer.
+                # (int) Maximum size of replay buffer. Usually, larger buffer size is better.
                 replay_buffer_size=100000,
             ),
         ),
     )
 
     def _monitor_vars_learn(self) -> List[str]:
+        """
+        Overview:
+            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+            as text logger, tensorboard logger, will use these keys to save the corresponding data.
+        Returns:
+            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+        """
         return ["q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]