From 1158cd59f7342687b9ff81d7c97e3f7222a4f9a8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <48008469+puyuan1996@users.noreply.github.com>
Date: Mon, 25 Nov 2024 22:50:55 +0800
Subject: [PATCH] feature(pu): add pistonball_env, its unittest and qmix config
 (#833)

* feature(pu): add pistonball_env, its unittest and qmix config
* polish(pu): pistonball reuse PTZRecordVideo
* polish(pu): adapt qmix's mixer to support image obs
* fix(pu): fix qmix's mixer to support image obs
* sync code
* polish(pu): polish ptz_pistonball_qmix_config.py
* polish(pu): polish qmix.py
* polish(pu): add normalize_reward in pistonball_env
* polish(pu): polish hyper-parameters in ptz_pistonball_qmix_config.py
* polish(pu): polish ptz_pistonball_qmix_config.py
* style(pu): yapf format
* polish(pu): polish comments in qmix
* polish(pu): polish qmix comments
---
 ding/model/template/qmix.py                   |  80 +++++-
 ding/model/template/tests/test_qmix.py        |  31 +++
 .../config/ptz_pistonball_qmix_config.py      |  79 ++++++
 .../envs/petting_zoo_pistonball_env.py        | 244 ++++++++++++++++++
 .../envs/petting_zoo_simple_spread_env.py     |  20 +-
 .../envs/test_petting_zoo_pistonball_env.py   | 106 ++++++++
 setup.py                                      |   2 +-
 7 files changed, 539 insertions(+), 23 deletions(-)
 create mode 100644 dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py
 create mode 100644 dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py
 create mode 100644 dizoo/petting_zoo/envs/test_petting_zoo_pistonball_env.py

diff --git a/ding/model/template/qmix.py b/ding/model/template/qmix.py
index 68354e0cf7..b5cde0806b 100644
--- a/ding/model/template/qmix.py
+++ b/ding/model/template/qmix.py
@@ -1,10 +1,13 @@
-from typing import Union, List
+from functools import reduce
+from typing import List, Union
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from functools import reduce
-from ding.utils import list_split, MODEL_REGISTRY
-from ding.torch_utils import fc_block, MLP
+
+from ding.torch_utils import MLP, fc_block
+from ding.utils import MODEL_REGISTRY, list_split
+
+from ..common import ConvEncoder
 from .q_learning import DRQN
@@ -111,7 +114,7 @@ def __init__(
         self,
         agent_num: int,
         obs_shape: int,
-        global_obs_shape: int,
+        global_obs_shape: Union[int, List[int]],
         action_shape: int,
         hidden_size_list: list,
         mixer: bool = True,
@@ -146,8 +149,34 @@ def __init__(
         embedding_size = hidden_size_list[-1]
         self.mixer = mixer
         if self.mixer:
-            self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
-            self._global_state_encoder = nn.Identity()
+            global_obs_shape_type = self._get_global_obs_shape_type(global_obs_shape)
+
+            if global_obs_shape_type == "flat":
+                self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
+                self._global_state_encoder = nn.Identity()
+            elif global_obs_shape_type == "image":
+                self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
+                self._global_state_encoder = ConvEncoder(
+                    global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN'
+                )
+            else:
+                raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")
+
+    def _get_global_obs_shape_type(self, global_obs_shape: Union[int, List[int]]) -> str:
+        """
+        Overview:
+            Determine the type of the global observation shape.
+        Arguments:
+            - global_obs_shape (:obj:`Union[int, List[int]]`): The shape of the global observation, \
+                an int (or a single-element list) for a flat vector, or a 3-element list for an image.
+        Returns:
+            - obs_shape_type (:obj:`str`): 'flat' for 1D observations or 'image' for 3D observations.
+        """
+        if isinstance(global_obs_shape, int) or (isinstance(global_obs_shape, list) and len(global_obs_shape) == 1):
+            return "flat"
+        elif isinstance(global_obs_shape, list) and len(global_obs_shape) == 3:
+            return "image"
+        else:
+            raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")
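Reviewer note (not part of the patch): a minimal sketch of how the two branches are selected, assuming DI-engine is installed; all shapes here are illustrative.

```python
import torch
from ding.model.template.qmix import QMix

# Flat global state: the Mixer consumes the raw 1D state through nn.Identity.
flat_model = QMix(agent_num=4, obs_shape=32, global_obs_shape=128,
                  action_shape=9, hidden_size_list=[64, 128, 64])

# Image global state: a ConvEncoder first maps (C, H, W) to an embedding of
# size hidden_size_list[-1], and the Mixer consumes that embedding instead.
image_model = QMix(agent_num=4, obs_shape=32, global_obs_shape=[3, 64, 64],
                   action_shape=9, hidden_size_list=[64, 128, 64])
```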

     def forward(self, data: dict, single_step: bool = True) -> dict:
         """
@@ -182,8 +211,16 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
         agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
             'prev_state']
         action = data.get('action', None)
+        # If single_step is True, add a time dimension at the front of agent_state:
+        # the model expects a time-step dimension even when processing a single step.
         if single_step:
-            agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
+            agent_state = agent_state.unsqueeze(0)
+        # Only a 2-dimensional (flat, without a time dimension) global_state needs the
+        # extra front dimension here, so that it lines up with agent_state in the
+        # subsequent computation.
+        if single_step and len(global_state.shape) == 2:
+            global_state = global_state.unsqueeze(0)
         T, B, A = agent_state.shape[:3]
         assert len(prev_state) == B and all(
             [len(p) == A for p in prev_state]
@@ -205,15 +242,38 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
             agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
             agent_q_act = agent_q_act.squeeze(-1)  # T, B, A
         if self.mixer:
-            global_state_embedding = self._global_state_encoder(global_state)
+            global_state_embedding = self._process_global_state(global_state)
             total_q = self._mixer(agent_q_act, global_state_embedding)
         else:
-            total_q = agent_q_act.sum(-1)
+            total_q = agent_q_act.sum(dim=-1)
+
         if single_step:
             total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)
+
         return {
             'total_q': total_q,
             'logit': agent_q,
             'next_state': next_state,
             'action_mask': data['obs']['action_mask']
         }
+
+    def _process_global_state(self, global_state: torch.Tensor) -> torch.Tensor:
+        """
+        Overview:
+            Process the global state to obtain an embedding.
+        Arguments:
+            - global_state (:obj:`torch.Tensor`): The global state tensor.
+        Returns:
+            - global_state_embedding (:obj:`torch.Tensor`): The processed global state embedding.
+        """
+        # A 5-dimensional global_state is an image sequence of shape [T, B, C, H, W].
+        if global_state.dim() == 5:
+            # Collapse the time and batch dimensions, encode, then restore them.
+            batch_time_shape = global_state.shape[:2]  # [T, B]
+            reshaped_state = global_state.view(-1, *global_state.shape[-3:])  # [T*B, C, H, W]
+            encoded_state = self._global_state_encoder(reshaped_state)
+            return encoded_state.view(*batch_time_shape, -1)  # [T, B, embedding_dim]
+        else:
+            # For lower-dimensional (flat) states, apply the encoder directly.
+            return self._global_state_encoder(global_state)
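Reviewer note: the collapse-then-restore reshape above is the standard way to run a CNN over a [T, B, ...] sequence. A self-contained sketch with a stand-in encoder (not the ding ConvEncoder):

```python
import torch
import torch.nn as nn

# Stand-in encoder; only the shape handling is the point here.
encoder = nn.Sequential(nn.Conv2d(3, 8, kernel_size=8, stride=4), nn.Flatten(), nn.LazyLinear(64))

state = torch.randn(8, 3, 3, 64, 64)      # [T, B, C, H, W]
flat = state.view(-1, *state.shape[-3:])  # [T*B, C, H, W]: one big batch for the CNN
emb = encoder(flat)                       # [T*B, 64]
emb = emb.view(*state.shape[:2], -1)      # [T, B, 64]: sequence structure restored
print(emb.shape)                          # torch.Size([8, 3, 64])
```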
diff --git a/ding/model/template/tests/test_qmix.py b/ding/model/template/tests/test_qmix.py
index ce1817b697..74062357e7 100644
--- a/ding/model/template/tests/test_qmix.py
+++ b/ding/model/template/tests/test_qmix.py
@@ -43,3 +43,34 @@ def test_qmix():
         is_differentiable(loss, qmix_model)
     data.pop('action')
     output = qmix_model(data, single_step=False)
+
+
+@pytest.mark.unittest
+def test_qmix_process_global_state():
+    # Test the behavior of the _process_global_state method with different global_obs_shape types
+    agent_num, obs_dim, global_obs_dim, action_dim = 4, 32, 32 * 4, 9
+    embedding_dim = 64
+
+    # Case 1: Test "flat" type global_obs_shape
+    global_obs_shape = global_obs_dim  # Flat global_obs_shape
+    qmix_model_flat = QMix(agent_num, obs_dim, global_obs_shape, action_dim, [64, 128, embedding_dim], mixer=True)
+
+    # Simulate input for the "flat" type global_state
+    batch_size, time_steps = 3, 8
+    global_state_flat = torch.randn(batch_size, time_steps, global_obs_dim)
+    processed_flat = qmix_model_flat._process_global_state(global_state_flat)
+
+    # The flat path encodes with nn.Identity, so the last dimension stays global_obs_dim
+    assert processed_flat.shape == (batch_size, time_steps, global_obs_dim)
+
+    # Case 2: Test "image" type global_obs_shape
+    global_obs_shape = [3, 64, 64]  # Image-shaped global_obs_shape (C, H, W)
+    qmix_model_image = QMix(agent_num, obs_dim, global_obs_shape, action_dim, [64, 128, embedding_dim], mixer=True)
+
+    # Simulate input for the "image" type global_state
+    C, H, W = global_obs_shape
+    global_state_image = torch.randn(batch_size, time_steps, C, H, W)
+    processed_image = qmix_model_image._process_global_state(global_state_image)
+
+    # The image path encodes through ConvEncoder, so the last dimension is embedding_dim
+    assert processed_image.shape == (batch_size, time_steps, embedding_dim)
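Reviewer note: the two assertions differ because the flat path keeps the raw state dimension (Identity encoder) while the image path projects to hidden_size_list[-1]. The image case can be reproduced interactively, mirroring the test values:

```python
import torch
from ding.model.template.qmix import QMix

model = QMix(4, 32, [3, 64, 64], 9, [64, 128, 64], mixer=True)
out = model._process_global_state(torch.randn(3, 8, 3, 64, 64))
print(out.shape)  # torch.Size([3, 8, 64]), i.e. hidden_size_list[-1]
```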
diff --git a/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py b/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py
new file mode 100644
index 0000000000..3816db6ef5
--- /dev/null
+++ b/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py
@@ -0,0 +1,79 @@
+from easydict import EasyDict
+
+n_pistons = 20
+collector_env_num = 8
+evaluator_env_num = 8
+max_env_step = 3e6
+
+main_config = dict(
+    exp_name=f'data_pistonball/ptz_pistonball_n{n_pistons}_qmix_seed0',
+    env=dict(
+        env_family='butterfly',
+        env_id='pistonball_v6',
+        n_pistons=n_pistons,
+        max_cycles=125,
+        agent_obs_only=False,
+        continuous_actions=False,
+        collector_env_num=collector_env_num,
+        evaluator_env_num=evaluator_env_num,
+        n_evaluator_episode=evaluator_env_num,
+        stop_value=1e6,
+        manager=dict(shared_memory=False, ),
+    ),
+    policy=dict(
+        cuda=True,
+        model=dict(
+            agent_num=n_pistons,
+            obs_shape=(3, 457, 120),  # RGB image observation shape for each piston agent
+            global_obs_shape=(3, 560, 880),  # Global state shape
+            action_shape=3,  # Discrete actions (0, 1, 2)
+            hidden_size_list=[32, 64, 128, 256],
+            mixer=True,
+        ),
+        learn=dict(
+            update_per_collect=20,
+            batch_size=32,
+            learning_rate=0.0001,
+            clip_value=5,
+            target_update_theta=0.001,
+            discount_factor=0.99,
+            double_q=True,
+        ),
+        collect=dict(
+            n_sample=16,
+            unroll_len=5,
+            env_num=collector_env_num,
+        ),
+        eval=dict(env_num=evaluator_env_num),
+        other=dict(
+            eps=dict(
+                type='exp',
+                start=1.0,
+                end=0.05,
+                decay=100000,
+            ),
+            replay_buffer=dict(
+                replay_buffer_size=5000,
+            ),
+        ),
+    ),
+)
+main_config = EasyDict(main_config)
+
+create_config = dict(
+    env=dict(
+        import_names=['dizoo.petting_zoo.envs.petting_zoo_pistonball_env'],
+        type='petting_zoo_pistonball',
+    ),
+    env_manager=dict(type='subprocess'),
+    policy=dict(type='qmix'),
+)
+create_config = EasyDict(create_config)
+
+ptz_pistonball_qmix_config = main_config
+ptz_pistonball_qmix_create_config = create_config
+
+if __name__ == '__main__':
+    # or you can enter `ding -m serial -c ptz_pistonball_qmix_config.py -s 0`
+    from ding.entry import serial_pipeline
+    serial_pipeline((main_config, create_config), seed=0, max_env_step=max_env_step)
\ No newline at end of file
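Reviewer note: assuming DI-engine's 'exp' schedule has the usual exponential form eps(t) = end + (start - end) * exp(-t / decay), the settings above decay exploration roughly as follows:

```python
import math

start, end, decay = 1.0, 0.05, 100000  # values from the config above

def eps(step: int) -> float:
    # Assumed form of the 'exp' epsilon-greedy schedule.
    return end + (start - end) * math.exp(-step / decay)

print(round(eps(0), 3), round(eps(100000), 3), round(eps(500000), 3))
# 1.0 0.399 0.056
```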
diff --git a/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py b/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py
new file mode 100644
index 0000000000..775af37d6a
--- /dev/null
+++ b/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py
@@ -0,0 +1,244 @@
+import copy
+from functools import reduce
+from typing import Dict, List, Optional
+
+import gymnasium as gym
+import numpy as np
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.envs.common.common_function import affine_transform
+from ding.torch_utils import to_ndarray
+from ding.utils import ENV_REGISTRY
+from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PTZRecordVideo
+from pettingzoo.butterfly import pistonball_v6
+
+
+@ENV_REGISTRY.register('petting_zoo_pistonball')
+class PettingZooPistonballEnv(BaseEnv):
+    """
+    DI-engine PettingZoo environment adapter for the Pistonball environment.
+    This class integrates the `pistonball_v6` environment into the DI-engine
+    framework, supporting both continuous and discrete actions.
+    """
+
+    def __init__(self, cfg: dict) -> None:
+        self._cfg = cfg
+        self._init_flag = False
+        self._replay_path = None
+        self._num_pistons = self._cfg.get('n_pistons', 20)
+        self._continuous_actions = self._cfg.get('continuous_actions', False)
+        self._max_cycles = self._cfg.get('max_cycles', 125)
+        self._act_scale = self._cfg.get('act_scale', False)
+        self._agent_specific_global_state = self._cfg.get('agent_specific_global_state', False)
+        if self._act_scale:
+            assert self._continuous_actions, 'Action scaling only applies to continuous action spaces.'
+        self._channel_first = self._cfg.get('channel_first', True)
+        self.normalize_reward = self._cfg.get('normalize_reward', False)
+
+    def reset(self) -> np.ndarray:
+        """
+        Resets the environment and returns the initial observations.
+        """
+        if not self._init_flag:
+            # Initialize the pistonball environment
+            parallel_env = pistonball_v6.parallel_env
+            self._env = parallel_env(
+                n_pistons=self._num_pistons, continuous=self._continuous_actions, max_cycles=self._max_cycles
+            )
+            self._env.reset()
+            self._agents = self._env.agents
+
+            # Define action and observation spaces
+            self._action_space = gym.spaces.Dict({agent: self._env.action_space(agent) for agent in self._agents})
+            single_agent_obs_space = self._env.observation_space(self._agents[0])
+            single_agent_action_space = self._env.action_space(self._agents[0])
+
+            if isinstance(single_agent_action_space, gym.spaces.Box):
+                self._action_dim = single_agent_action_space.shape
+            elif isinstance(single_agent_action_space, gym.spaces.Discrete):
+                self._action_dim = (single_agent_action_space.n, )
+            else:
+                raise Exception('Only support `Box` or `Discrete` action space for a single agent.')
+
+            if isinstance(single_agent_obs_space, gym.spaces.Box):
+                self._obs_shape = single_agent_obs_space.shape
+            else:
+                raise ValueError('Only support `Box` observation space for each agent.')
+
+            self._observation_space = gym.spaces.Box(
+                low=0, high=255, shape=(self._num_pistons, *self._obs_shape), dtype=np.uint8
+            )
+
+            self._reward_space = gym.spaces.Dict(
+                {
+                    agent: gym.spaces.Box(low=float('-inf'), high=float('inf'), shape=(1, ), dtype=np.float32)
+                    for agent in self._agents
+                }
+            )
+
+            if self._replay_path is not None:
+                self._env.render_mode = 'rgb_array'
+                self._env = PTZRecordVideo(
+                    self._env, self._replay_path, name_prefix=f'rl-video-{id(self)}', disable_logger=True
+                )
+            self._init_flag = True
+
+        if hasattr(self, '_seed'):
+            obs = self._env.reset(seed=self._seed)
+        else:
+            obs = self._env.reset()
+
+        self._eval_episode_return = 0.0
+        self._step_count = 0
+        obs_n = self._process_obs(obs)
+        return obs_n
+
+    def close(self) -> None:
+        """
+        Closes the environment.
+        """
+        if self._init_flag:
+            self._env.close()
+            self._init_flag = False
+
+    def render(self) -> None:
+        """
+        Renders the environment.
+        """
+        self._env.render()
+
+    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+        """
+        Sets the seed for the environment.
+        """
+        self._seed = seed
+        self._dynamic_seed = dynamic_seed
+        np.random.seed(self._seed)
+
+    def step(self, action: np.ndarray) -> BaseEnvTimestep:
+        """
+        Steps through the environment using the provided action.
+        """
+        self._step_count += 1
+        assert isinstance(action, np.ndarray), type(action)
+        action = self._process_action(action)
+        if self._act_scale:
+            for agent in self._agents:
+                action[agent] = affine_transform(
+                    action[agent], min_val=self.action_space[agent].low, max_val=self.action_space[agent].high
+                )
+
+        obs, rew, done, trunc, info = self._env.step(action)
+        obs_n = self._process_obs(obs)
+        rew_n = np.array([sum([rew[agent] for agent in self._agents])])
+        rew_n = rew_n.astype(np.float32)
+
+        if self.normalize_reward:
+            # TODO: more elegant scale factor
+            rew_n = rew_n / (self._num_pistons * 50)
+
+        self._eval_episode_return += rew_n.item()
+
+        done_n = reduce(lambda x, y: x and y, done.values()) or self._step_count >= self._max_cycles
+        if done_n:
+            info['eval_episode_return'] = self._eval_episode_return
+
+        return BaseEnvTimestep(obs_n, rew_n, done_n, info)
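Reviewer note: act_scale relies on ding's affine_transform to map network outputs in [-1, 1] onto each agent's Box bounds. A standalone re-implementation of the assumed mapping, for illustration only (not the ding function itself):

```python
import numpy as np

def affine_map(x: np.ndarray, min_val, max_val) -> np.ndarray:
    # Assumed behavior of affine_transform: clip to [-1, 1], then map linearly
    # onto [min_val, max_val].
    return min_val + (np.clip(x, -1.0, 1.0) + 1.0) * 0.5 * (max_val - min_val)

print(affine_map(np.array([-1.0, 0.0, 1.0]), -1.0, 1.0))  # [-1.  0.  1.]
```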
+
+    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+        """
+        Enables video recording during the episode.
+        """
+        if replay_path is None:
+            replay_path = './video'
+        self._replay_path = replay_path
+
+    def _process_obs(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
+        """
+        Processes the observations into the required format.
+        """
+        # Process agent observations, transpose if channel_first is True
+        obs = np.array(
+            [np.transpose(obs[agent], (2, 0, 1)) if self._channel_first else obs[agent] for agent in self._agents],
+            dtype=np.uint8
+        )
+
+        # Return only agent observations if configured to do so
+        if self._cfg.get('agent_obs_only', False):
+            return obs
+
+        # Initialize return dictionary
+        ret = {'agent_state': (obs / 255.0).astype(np.float32)}
+
+        # Obtain global state, transpose if channel_first is True
+        global_state = self._env.state()
+        if self._channel_first:
+            global_state = global_state.transpose(2, 0, 1)
+        ret['global_state'] = (global_state / 255.0).astype(np.float32)
+
+        # Handle agent-specific global states by repeating the global state for each agent
+        if self._agent_specific_global_state:
+            ret['global_state'] = np.tile(
+                np.expand_dims(ret['global_state'], axis=0), (self._num_pistons, 1, 1, 1)
+            )
+
+        # Set action mask for each agent
+        ret['action_mask'] = np.ones((self._num_pistons, *self._action_dim), dtype=np.float32)
+
+        return ret
+
+    def _process_action(self, action: np.ndarray) -> Dict[str, np.ndarray]:
+        """
+        Processes the action array into a dictionary format for each agent.
+        """
+        dict_action = {}
+        for i, agent in enumerate(self._agents):
+            dict_action[agent] = action[i]
+        return dict_action
+
+    def random_action(self) -> np.ndarray:
+        """
+        Generates a random action for each agent.
+        """
+        random_action = self.action_space.sample()
+        for k in random_action:
+            if isinstance(random_action[k], np.ndarray):
+                pass
+            elif isinstance(random_action[k], int):
+                random_action[k] = to_ndarray([random_action[k]], dtype=np.int64)
+        return random_action
+
+    @property
+    def agents(self) -> List[str]:
+        return self._agents
+
+    @property
+    def observation_space(self) -> gym.spaces.Space:
+        return self._observation_space
+
+    @property
+    def action_space(self) -> gym.spaces.Space:
+        return self._action_space
+
+    @property
+    def reward_space(self) -> gym.spaces.Space:
+        return self._reward_space
+
+    @staticmethod
+    def create_collector_env_cfg(cfg: dict) -> List[dict]:
+        collector_env_num = cfg.pop('collector_env_num')
+        cfg = copy.deepcopy(cfg)
+        cfg.normalize_reward = True
+        return [cfg for _ in range(collector_env_num)]
+
+    @staticmethod
+    def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+        evaluator_env_num = cfg.pop('evaluator_env_num')
+        cfg = copy.deepcopy(cfg)
+        cfg.normalize_reward = False
+        return [cfg for _ in range(evaluator_env_num)]
+
+    def __repr__(self) -> str:
+        return "DI-engine PettingZoo Pistonball Env"
\ No newline at end of file
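Reviewer note: a usage sketch of the new env (config keys mirror the tests below; requires pettingzoo[butterfly]). In the pipeline, normalize_reward is injected by create_collector_env_cfg / create_evaluator_env_cfg:

```python
import numpy as np
from easydict import EasyDict
from dizoo.petting_zoo.envs.petting_zoo_pistonball_env import PettingZooPistonballEnv

env = PettingZooPistonballEnv(
    EasyDict(dict(
        n_pistons=20, max_cycles=125, agent_obs_only=False,
        agent_specific_global_state=False, continuous_actions=True,
        act_scale=False, normalize_reward=False,
    ))
)
env.seed(0)
obs = env.reset()
print(obs['agent_state'].shape, obs['global_state'].shape)  # (20, 3, 457, 120) (3, 560, 880)
rand = env.random_action()
timestep = env.step(np.array([rand[agent] for agent in env.agents]))
print(timestep.reward, timestep.done)
env.close()
```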
diff --git a/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py
index 4be9687fe8..4f5916f33b 100644
--- a/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py
+++ b/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py
@@ -13,18 +13,13 @@
 from pettingzoo.mpe.simple_spread.simple_spread import Scenario


+# Custom wrapper for recording videos in PettingZoo environments
 class PTZRecordVideo(gym.wrappers.RecordVideo):

     def step(self, action):
         """Steps through the environment using action, recording observations if :attr:`self.recording`."""
         # gymnasium==0.27.1
-        (
-            observations,
-            rewards,
-            terminateds,
-            truncateds,
-            infos,
-        ) = self.env.step(action)
+        observations, rewards, terminateds, truncateds, infos = self.env.step(action)

         # Because pettingzoo returns a dict of terminated and truncated, we need to check if any of the values are True
         if not (self.terminated is True or self.truncated is True):  # the first location for modifications
@@ -40,6 +35,7 @@ def step(self, action):
             self.terminated = terminateds[0]
             self.truncated = truncateds[0]

+        # Capture the video frame if recording
         if self.recording:
             assert self.video_recorder is not None
             self.video_recorder.capture_frame()
@@ -102,11 +98,11 @@ def reset(self) -> np.ndarray:
             self._agents = self._env.agents

             self._action_space = gym.spaces.Dict({agent: self._env.action_space(agent) for agent in self._agents})
-            single_agent_obs_space = self._env.action_space(self._agents[0])
-            if isinstance(single_agent_obs_space, gym.spaces.Box):
-                self._action_dim = single_agent_obs_space.shape
-            elif isinstance(single_agent_obs_space, gym.spaces.Discrete):
-                self._action_dim = (single_agent_obs_space.n, )
+            single_agent_action_space = self._env.action_space(self._agents[0])
+            if isinstance(single_agent_action_space, gym.spaces.Box):
+                self._action_dim = single_agent_action_space.shape
+            elif isinstance(single_agent_action_space, gym.spaces.Discrete):
+                self._action_dim = (single_agent_action_space.n, )
             else:
                 raise Exception('Only support `Box` or `Discrete` obs space for single agent.')
diff --git a/dizoo/petting_zoo/envs/test_petting_zoo_pistonball_env.py b/dizoo/petting_zoo/envs/test_petting_zoo_pistonball_env.py
new file mode 100644
index 0000000000..ea5ac988d7
--- /dev/null
+++ b/dizoo/petting_zoo/envs/test_petting_zoo_pistonball_env.py
@@ -0,0 +1,106 @@
+from easydict import EasyDict
+import pytest
+import numpy as np
+from dizoo.petting_zoo.envs.petting_zoo_pistonball_env import PettingZooPistonballEnv
+
+
+@pytest.mark.envtest
+class TestPettingZooPistonballEnv:
+
+    def test_agent_obs_only(self):
+        n_pistons = 20
+        env = PettingZooPistonballEnv(
+            EasyDict(
+                dict(
+                    n_pistons=n_pistons,
+                    max_cycles=125,
+                    agent_obs_only=True,
+                    continuous_actions=True,
+                    act_scale=False,
+                )
+            )
+        )
+        env.seed(123)
+        assert env._seed == 123
+        obs = env.reset()
+        assert obs.shape == (n_pistons, 3, 457, 120)
+        for i in range(10):
+            random_action = env.random_action()
+            random_action = np.array([random_action[agent] for agent in random_action])
+            timestep = env.step(random_action)
+            # print(timestep)
+            assert isinstance(timestep.obs, np.ndarray), timestep.obs
+            assert timestep.obs.shape == (n_pistons, 3, 457, 120)
+            assert isinstance(timestep.done, bool), timestep.done
+            assert isinstance(timestep.reward, np.ndarray), timestep.reward
+            assert timestep.reward.dtype == np.float32
+        print(env.observation_space, env.action_space, env.reward_space)
+        env.close()
+
+    def test_dict_obs(self):
+        n_pistons = 20
+        env = PettingZooPistonballEnv(
+            EasyDict(
+                dict(
+                    n_pistons=n_pistons,
+                    max_cycles=125,
+                    agent_obs_only=False,
+                    agent_specific_global_state=False,
+                    continuous_actions=True,
+                    act_scale=False,
+                )
+            )
+        )
+        env.seed(123)
+        assert env._seed == 123
+        obs = env.reset()
+        for k, v in obs.items():
+            print(k, v.shape)
+        for i in range(10):
+            random_action = env.random_action()
+            random_action = np.array([random_action[agent] for agent in random_action])
+            timestep = env.step(random_action)
+            # print(timestep)
+            assert isinstance(timestep.obs['agent_state'], np.ndarray), timestep.obs['agent_state']
+            assert isinstance(timestep.obs['global_state'], np.ndarray), timestep.obs['global_state']
+            assert timestep.obs['agent_state'].shape == (n_pistons, 3, 457, 120)
+            assert timestep.obs['global_state'].shape == (3, 560, 880)
+            assert isinstance(timestep.done, bool), timestep.done
+            assert isinstance(timestep.reward, np.ndarray), timestep.reward
+        print(env.observation_space, env.action_space, env.reward_space)
+        env.close()
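Reviewer note: the next test exercises agent_specific_global_state, which tiles one global frame per agent; the shape effect in isolation:

```python
import numpy as np

g = np.zeros((3, 560, 880), dtype=np.float32)          # one global frame (C, H, W)
tiled = np.tile(np.expand_dims(g, axis=0), (20, 1, 1, 1))
print(tiled.shape)  # (20, 3, 560, 880), matching the assertion below
```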
+
+    def test_agent_specific_global_state(self):
+        n_pistons = 20
+        env = PettingZooPistonballEnv(
+            EasyDict(
+                dict(
+                    n_pistons=n_pistons,
+                    max_cycles=125,
+                    agent_obs_only=False,
+                    continuous_actions=True,
+                    agent_specific_global_state=True,
+                    act_scale=False,
+                )
+            )
+        )
+        env.seed(123)
+        assert env._seed == 123
+        obs = env.reset()
+        for k, v in obs.items():
+            print(k, v.shape)
+        for i in range(10):
+            random_action = env.random_action()
+            random_action = np.array([random_action[agent] for agent in random_action])
+            timestep = env.step(random_action)
+            # print(timestep)
+            assert isinstance(timestep.obs['agent_state'], np.ndarray), timestep.obs['agent_state']
+            assert isinstance(timestep.obs['global_state'], np.ndarray), timestep.obs['global_state']
+            assert timestep.obs['agent_state'].shape == (n_pistons, 3, 457, 120)
+            assert timestep.obs['global_state'].shape == (n_pistons, 3, 560, 880)
+            assert isinstance(timestep.done, bool), timestep.done
+            assert isinstance(timestep.reward, np.ndarray), timestep.reward
+        print(env.observation_space, env.action_space, env.reward_space)
+        env.close()
\ No newline at end of file
diff --git a/setup.py b/setup.py
index ee18a4c415..59dd9ce782 100644
--- a/setup.py
+++ b/setup.py
@@ -75,7 +75,7 @@
         'responses',  # interaction
         'URLObject',  # interaction
         'pynng',  # parallel
-        'sniffio',  # parallel
+        'sniffio',  # parallel
         'redis',  # parallel
         'mpire>=2.3.5',  # parallel
     ],