Skip to content

Commit

Permalink
polish(rjy): polish comments in wqmix/ngu/pg model (#739)
Browse files Browse the repository at this point in the history
* polish(rjy): polish comments in wqmix model

* polish(rjy): polish comments in ngu model

* polish(rjy): polish comments in pg model

* polish(rjy): polish according to comments
  • Loading branch information
nighood authored Oct 31, 2023
1 parent 3034731 commit c005205
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 88 deletions.
64 changes: 35 additions & 29 deletions ding/model/template/ngu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@


def parallel_wrapper(forward_fn: Callable) -> Callable:
r"""
"""
Overview:
Process timestep T and batch_size B at the same time, in other words, treat different timestep data as
Process timestep T and batch_size B at the same time, in other words, treat different timestep data as \
different trajectories in a batch.
Arguments:
- forward_fn (:obj:`Callable`): Normal ``nn.Module`` 's forward function.
Expand Down Expand Up @@ -44,9 +44,12 @@ def reshape(d):
class NGU(nn.Module):
"""
Overview:
The recurrent Q model for NGU policy, modified from the class DRQN in q_leaning.py
input: x_t, a_{t-1}, r_e_{t-1}, r_i_{t-1}, beta
output:
The recurrent Q model for NGU(https://arxiv.org/pdf/2002.06038.pdf) policy, modified from the class DRQN in \
q_leaning.py. The implementation mentioned in the original paper is 'adapt the R2D2 agent that uses the \
dueling network architecture with an LSTM layer after a convolutional neural network'. The NGU network \
includes encoder, LSTM core(rnn) and head.
Interface:
``__init__``, ``forward``.
"""

def __init__(
Expand All @@ -62,20 +65,26 @@ def __init__(
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None
) -> None:
r"""
"""
Overview:
Init the DRQN Model according to arguments.
Init the DRQN Model for NGU according to arguments.
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Observation's space.
- action_shape (:obj:`Union[int, SequenceType]`): Action's space.
- encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``
- head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to ``Head``.
- lstm_type (:obj:`Optional[str]`): Version of rnn cell, now support ['normal', 'pytorch', 'hpc', 'gru']
- obs_shape (:obj:`Union[int, SequenceType]`): Observation's space, such as 8 or [4, 84, 84].
- action_shape (:obj:`Union[int, SequenceType]`): Action's space, such as 6 or [2, 3, 3].
- encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``.
- collector_env_num (:obj:`Optional[int]`): The number of environments used to collect data simultaneously.
- dueling (:obj:`bool`): Whether choose ``DuelingHead`` (True) or ``DiscreteHead (False)``, \
default to True.
- head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to ``Head``, should match the \
last element of ``encoder_hidden_size_list``.
- head_layer_num (:obj:`int`): The number of layers in head network.
- lstm_type (:obj:`Optional[str]`): Version of rnn cell, now support ['normal', 'pytorch', 'hpc', 'gru'], \
default is 'normal'.
- activation (:obj:`Optional[nn.Module]`):
The type of activation function to use in ``MLP`` the after ``layer_fn``,
if ``None`` then default set to ``nn.ReLU()``
The type of activation function to use in ``MLP`` the after ``layer_fn``, \
if ``None`` then default set to ``nn.ReLU()``.
- norm_type (:obj:`Optional[str]`):
The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details`
The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details`.
"""
super(NGU, self).__init__()
# For compatibility: 1, (1, ), [4, H, H]
Expand Down Expand Up @@ -122,32 +131,29 @@ def __init__(
def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: Optional[list] = None) -> Dict:
r"""
Overview:
Use observation, prev_action prev_reward_extrinsic to predict NGU Q output.
Parameter updates with NGU's MLPs forward setup.
Forward computation graph of NGU R2D2 network. Input observation, prev_action prev_reward_extrinsic \
to predict NGU Q output. Parameter updates with NGU's MLPs forward setup.
Arguments:
- inputs (:obj:`Dict`):
- inference: (:obj:'bool'): if inference is True, we unroll the one timestep transition,
- obs (:obj:`torch.Tensor`): Encoded observation.
- prev_state (:obj:`list`): Previous state's tensor of size ``(B, N)``.
- inference: (:obj:'bool'): If inference is True, we unroll the one timestep transition, \
if inference is False, we unroll the sequence transitions.
- saved_state_timesteps: (:obj:'Optional[list]'): when inference is False,
we unroll the sequence transitions, then we would save rnn hidden states at timesteps
- saved_state_timesteps: (:obj:'Optional[list]'): When inference is False, \
we unroll the sequence transitions, then we would save rnn hidden states at timesteps \
that are listed in list saved_state_timesteps.
ArgumentsKeys:
- obs (:obj:`torch.Tensor`): Encoded observation
- prev_state (:obj:`list`): Previous state's tensor of size ``(B, N)``
Returns:
- outputs (:obj:`Dict`):
Run ``MLP`` with ``DRQN`` setups and return the result prediction dictionary.
ReturnsKeys:
- logit (:obj:`torch.Tensor`): Logit tensor with same size as input ``obs``.
- next_state (:obj:`list`): Next state's tensor of size ``(B, N)``
- next_state (:obj:`list`): Next state's tensor of size ``(B, N)``.
Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, N=obs_space)`, where B is batch size.
- prev_state(:obj:`torch.FloatTensor list`): :math:`[(B, N)]`
- logit (:obj:`torch.FloatTensor`): :math:`(B, N)`
- next_state(:obj:`torch.FloatTensor list`): :math:`[(B, N)]`
- prev_state(:obj:`torch.FloatTensor list`): :math:`[(B, N)]`.
- logit (:obj:`torch.FloatTensor`): :math:`(B, N)`.
- next_state(:obj:`torch.FloatTensor list`): :math:`[(B, N)]`.
"""
x, prev_state = inputs['obs'], inputs['prev_state']
if 'prev_action' in inputs.keys():
Expand Down
44 changes: 44 additions & 0 deletions ding/model/template/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@

@MODEL_REGISTRY.register('pg')
class PG(nn.Module):
"""
Overview:
The neural network and computation graph of algorithms related to Policy Gradient(PG) \
(https://proceedings.neurips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf). \
The PG model is composed of two parts: encoder and head. Encoders are used to extract the feature \
from various observation. Heads are used to predict corresponding action logit.
Interface:
``__init__``, ``forward``.
"""

def __init__(
self,
Expand All @@ -23,6 +32,31 @@ def __init__(
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None
) -> None:
"""
Overview:
Initialize the PG model according to corresponding input arguments.
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
- action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
- action_space (:obj:`str`): The type of different action spaces, including ['discrete', 'continuous'], \
then will instantiate corresponding head, including ``DiscreteHead`` and ``ReparameterizationHead``.
- encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
the last element must match ``head_hidden_size``.
- head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``head`` network, defaults \
to None, it must match the last element of ``encoder_hidden_size_list``.
- head_layer_num (:obj:`int`): The num of layers used in the ``head`` network to compute action.
- activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
if ``None`` then default set it to ``nn.ReLU()``.
- norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']
Examples:
>>> model = PG((4, 84, 84), 5)
>>> inputs = torch.randn(8, 4, 84, 84)
>>> outputs = model(inputs)
>>> assert isinstance(outputs, dict)
>>> assert outputs['logit'].shape == (8, 5)
>>> assert outputs['dist'].sample().shape == (8, )
"""
super(PG, self).__init__()
# For compatibility: 1, (1, ), [4, 32, 32]
obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
Expand Down Expand Up @@ -57,6 +91,16 @@ def __init__(
raise KeyError("not support action space: {}".format(self.action_space))

def forward(self, x: torch.Tensor) -> Dict:
"""
Overview:
PG forward computation graph, input observation tensor to predict policy distribution.
Arguments:
- x (:obj:`torch.Tensor`): The input observation tensor data.
Returns:
- outputs (:obj:`torch.distributions`): The output policy distribution. If action space is \
discrete, the output is Categorical distribution; if action space is continuous, the output is Normal \
distribution.
"""
x = self.encoder(x)
x = self.head(x)
if self.action_space == 'discrete':
Expand Down
Loading

0 comments on commit c005205

Please sign in to comment.