Skip to content

Commit

Permalink
feature(pu): add pistonball_env, its unittest and qmix config (#833)
Browse files Browse the repository at this point in the history
* feature(pu): add pistonball_env, its unittest and qmix config

* polish(pu): pistonball reuse PTZRecordVideo

* polish(pu): adapt qmix's mixer to support image obs

* fix(pu): fix qmix's mixer to support image obs

* sync code

* polish(pu): polish ptz_pistonball_qmix_config.py

* polish(pu): polish qmix.py

* polish(pu): add normalize_reward in pistonball_env

* polish(pu): polish hyper-parameters in ptz_pistonball_qmix_config.py

* polish(pu): polish ptz_pistonball_qmix_config.py

* style(pu): yapf format

* polish(pu): polish comments in qmix

* polish(pu): polish qmix comments
  • Loading branch information
puyuan1996 authored Nov 25, 2024
1 parent 1f198e9 commit 1158cd5
Show file tree
Hide file tree
Showing 7 changed files with 539 additions and 23 deletions.
80 changes: 70 additions & 10 deletions ding/model/template/qmix.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Union, List
from functools import reduce
from typing import List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
from ding.utils import list_split, MODEL_REGISTRY
from ding.torch_utils import fc_block, MLP
from ding.torch_utils import MLP, fc_block
from ding.utils import MODEL_REGISTRY, list_split

from ..common import ConvEncoder
from .q_learning import DRQN


Expand Down Expand Up @@ -111,7 +114,7 @@ def __init__(
self,
agent_num: int,
obs_shape: int,
global_obs_shape: int,
global_obs_shape: Union[int, List[int]],
action_shape: int,
hidden_size_list: list,
mixer: bool = True,
Expand Down Expand Up @@ -146,8 +149,34 @@ def __init__(
embedding_size = hidden_size_list[-1]
self.mixer = mixer
if self.mixer:
self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
self._global_state_encoder = nn.Identity()
global_obs_shape_type = self._get_global_obs_shape_type(global_obs_shape)

if global_obs_shape_type == "flat":
self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
self._global_state_encoder = nn.Identity()
elif global_obs_shape_type == "image":
self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
self._global_state_encoder = ConvEncoder(
global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN'
)
else:
raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")

def _get_global_obs_shape_type(self, global_obs_shape: Union[int, List[int]]) -> str:
"""
Overview:
Determine the type of global observation shape.
Arguments:
- global_obs_shape (:obj:`Union[int, List[int]]`): The global observation state.
Returns:
- obs_shape_type (:obj:`str`): 'flat' for 1D observation or 'image' for 3D observation.
"""
if isinstance(global_obs_shape, int) or (isinstance(global_obs_shape, list) and len(global_obs_shape) == 1):
return "flat"
elif isinstance(global_obs_shape, list) and len(global_obs_shape) == 3:
return "image"
else:
raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")

def forward(self, data: dict, single_step: bool = True) -> dict:
"""
Expand Down Expand Up @@ -182,8 +211,16 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
'prev_state']
action = data.get('action', None)
# If single_step is True, add a new dimension at the front of agent_state
# This is necessary to maintain the expected input shape for the model,
# which requires a time step dimension even when processing a single step.
if single_step:
agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
agent_state = agent_state.unsqueeze(0)
# If single_step is True and global_state has 2 dimensions, add a new dimension at the front of global_state
# This ensures that global_state has the same number of dimensions as agent_state,
# allowing for consistent processing in the forward computation.
if single_step and len(global_state.shape) == 2:
global_state = global_state.unsqueeze(0)
T, B, A = agent_state.shape[:3]
assert len(prev_state) == B and all(
[len(p) == A for p in prev_state]
Expand All @@ -205,15 +242,38 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
agent_q_act = agent_q_act.squeeze(-1) # T, B, A
if self.mixer:
global_state_embedding = self._global_state_encoder(global_state)
global_state_embedding = self._process_global_state(global_state)
total_q = self._mixer(agent_q_act, global_state_embedding)
else:
total_q = agent_q_act.sum(-1)
total_q = agent_q_act.sum(dim=-1)

if single_step:
total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)

return {
'total_q': total_q,
'logit': agent_q,
'next_state': next_state,
'action_mask': data['obs']['action_mask']
}

def _process_global_state(self, global_state: torch.Tensor) -> torch.Tensor:
"""
Overview:
Process the global state to obtain an embedding.
Arguments:
- global_state (:obj:`torch.Tensor`): The global state tensor.
Returns:
- global_state_embedding (:obj:`torch.Tensor`): The processed global state embedding.
"""
# If global_state has 5 dimensions, it's likely in the form [batch_size, time_steps, C, H, W]
if global_state.dim() == 5:
# Reshape and apply the global state encoder
batch_time_shape = global_state.shape[:2] # [batch_size, time_steps]
reshaped_state = global_state.view(-1, *global_state.shape[-3:]) # Collapse batch and time dims
encoded_state = self._global_state_encoder(reshaped_state)
return encoded_state.view(*batch_time_shape, -1) # Reshape back to [batch_size, time_steps, embedding_dim]
else:
# For lower-dimensional states, apply the encoder directly
return self._global_state_encoder(global_state)
31 changes: 31 additions & 0 deletions ding/model/template/tests/test_qmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,34 @@ def test_qmix():
is_differentiable(loss, qmix_model)
data.pop('action')
output = qmix_model(data, single_step=False)


@pytest.mark.unittest
def test_qmix_process_global_state():
# Test the behavior of the _process_global_state method with different global_obs_shape types
agent_num, obs_dim, global_obs_dim, action_dim = 4, 32, 32 * 4, 9
embedding_dim = 64

# Case 1: Test "flat" type global_obs_shape
global_obs_shape = global_obs_dim # Flat global_obs_shape
qmix_model_flat = QMix(agent_num, obs_dim, global_obs_shape, action_dim, [64, 128, embedding_dim], mixer=True)

# Simulate input for the "flat" type global_state
batch_size, time_steps = 3, 8
global_state_flat = torch.randn(batch_size, time_steps, global_obs_dim)
processed_flat = qmix_model_flat._process_global_state(global_state_flat)

# Ensure the output shape is correct [batch_size, time_steps, embedding_dim]
assert processed_flat.shape == (batch_size, time_steps, global_obs_dim)

# Case 2: Test "image" type global_obs_shape
global_obs_shape = [3, 64, 64] # Image-shaped global_obs_shape (C, H, W)
qmix_model_image = QMix(agent_num, obs_dim, global_obs_shape, action_dim, [64, 128, embedding_dim], mixer=True)

# Simulate input for the "image" type global_state
C, H, W = global_obs_shape
global_state_image = torch.randn(batch_size, time_steps, C, H, W)
processed_image = qmix_model_image._process_global_state(global_state_image)

# Ensure the output shape is correct [batch_size, time_steps, embedding_dim]
assert processed_image.shape == (batch_size, time_steps, embedding_dim)
79 changes: 79 additions & 0 deletions dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from easydict import EasyDict

n_pistons = 20
collector_env_num = 8
evaluator_env_num = 8
max_env_step = 3e6

main_config = dict(
exp_name=f'data_pistonball/ptz_pistonball_n{n_pistons}_qmix_seed0',
env=dict(
env_family='butterfly',
env_id='pistonball_v6',
n_pistons=n_pistons,
max_cycles=125,
agent_obs_only=False,
continuous_actions=False,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
stop_value=1e6,
manager=dict(shared_memory=False,),
),
policy=dict(
cuda=True,
model=dict(
agent_num=n_pistons,
obs_shape=(3, 457, 120), # RGB image observation shape for each piston agent
global_obs_shape=(3, 560, 880), # Global state shape
action_shape=3, # Discrete actions (0, 1, 2)
hidden_size_list=[32, 64, 128, 256],
mixer=True,
),
learn=dict(
update_per_collect=20,
batch_size=32,
learning_rate=0.0001,
clip_value=5,
target_update_theta=0.001,
discount_factor=0.99,
double_q=True,
),
collect=dict(
n_sample=16,
unroll_len=5,
env_num=collector_env_num,
),
eval=dict(env_num=evaluator_env_num),
other=dict(
eps=dict(
type='exp',
start=1.0,
end=0.05,
decay=100000,
),
replay_buffer=dict(
replay_buffer_size=5000,
),
),
),
)
main_config = EasyDict(main_config)

create_config = dict(
env=dict(
import_names=['dizoo.petting_zoo.envs.petting_zoo_pistonball_env'],
type='petting_zoo_pistonball',
),
env_manager=dict(type='subprocess'),
policy=dict(type='qmix'),
)
create_config = EasyDict(create_config)

ptz_pistonball_qmix_config = main_config
ptz_pistonball_qmix_create_config = create_config

if __name__ == '__main__':
# or you can enter `ding -m serial -c ptz_pistonball_qmix_config.py -s 0`
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0, max_env_step=max_env_step)
Loading

0 comments on commit 1158cd5

Please sign in to comment.