Skip to content

Commit

Permalink
polish code
Browse files Browse the repository at this point in the history
  • Loading branch information
zjowowen committed Nov 23, 2023
1 parent 3687f8b commit 4fb85b0
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 10 deletions.
3 changes: 2 additions & 1 deletion ding/envs/env_manager/envpool_env_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ def reset(self) -> None:
obs = obs.astype(np.float32)
if self._cfg.image_observation:
obs /= 255.0
ready_obs = deep_merge_dicts({i: o for i, o in zip(env_id, obs)}, ready_obs)
for i in range(len(list(env_id))):
ready_obs[env_id[i]] = obs[i]
if len(ready_obs) == self._env_num:
break
self._eval_episode_return = [0. for _ in range(self._env_num)]
Expand Down
52 changes: 43 additions & 9 deletions ding/envs/env_manager/tests/test_envpool_env_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from easydict import EasyDict

from ..envpool_env_manager import PoolEnvManager
from ding.envs.env_manager.envpool_env_manager import PoolEnvManager, PoolEnvManagerV2

env_num_args = [[16, 8], [8, 8]]

Expand All @@ -30,17 +30,51 @@ def test_naive(self, env_num, batch_size):
env_manager = PoolEnvManager(env_manager_cfg)
assert env_manager._closed
env_manager.launch()
# Test step
start_time = time.time()
for count in range(20):
for count in range(5):
env_id = env_manager.ready_obs.keys()
action = {i: np.random.randint(4) for i in env_id}
timestep = env_manager.step(action)
assert len(timestep) == env_manager_cfg.batch_size
print('Count {}'.format(count))
print([v.info for v in timestep.values()])
end_time = time.time()
print('total step time: {}'.format(end_time - start_time))
# Test close
env_manager.close()
assert env_manager._closed


@pytest.mark.envpooltest
@pytest.mark.parametrize('env_num, batch_size', env_num_args)
class TestPoolEnvManagerV2:

def test_naive(self, env_num, batch_size):
env_manager_cfg = EasyDict(
{
'env_id': 'Pong-v5',
'env_num': env_num,
'batch_size': batch_size,
'seed': 3,
# env wrappers
'episodic_life': False,
'reward_clip': False,
'gray_scale': True,
'stack_num': 4,
'frame_skip': 4,
}
)
env_manager = PoolEnvManagerV2(env_manager_cfg)
assert env_manager._closed
ready_obs = env_manager.launch()
env_id = list(ready_obs.keys())
for count in range(5):
action = {i: np.random.randint(4) for i in env_id}
action_send = np.array([action[i] for i in action.keys()])
env_id_send = np.array(list(action.keys()))
env_manager.send_action(action_send, env_id_send)
next_obs, rew, done, info = env_manager.receive_data()
assert next_obs.shape == (env_manager_cfg.batch_size, 4, 84, 84)
assert rew.shape == (env_manager_cfg.batch_size, )
assert done.shape == (env_manager_cfg.batch_size, )
assert info['env_id'].shape == (env_manager_cfg.batch_size, )
env_manager.close()
assert env_manager._closed


if __name__ == "__main__":
TestPoolEnvManagerV2().test_naive(16, 8)

0 comments on commit 4fb85b0

Please sign in to comment.