Fix the implementation to make the tests pass
pseudo-rnd-thoughts committed Nov 23, 2023
1 parent a85a6e0 commit 6cb9482
Showing 5 changed files with 24 additions and 23 deletions.
2 changes: 1 addition & 1 deletion gymnasium/vector/async_vector_env.py
@@ -668,7 +668,7 @@ def _async_worker(
truncated,
info,
) = env.step(data)
- autoreset = terminated or truncated
+ autoreset = terminated or truncated

if shared_memory:
write_to_shared_memory(
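For context, the changed line sits in the async worker's step handling, where `autoreset` records whether the next step command should reset the sub-environment instead of stepping it. Below is a simplified, self-contained sketch of that next-step autoreset pattern; the environment id, seed and step count are placeholders, and the real worker exchanges actions and results over a pipe rather than running a local loop.

```python
import gymnasium as gym

# Illustrative sketch of next-step autoreset, as used by the async worker loop.
env = gym.make("CartPole-v1")
observation, info = env.reset(seed=0)
autoreset = False

for _ in range(300):
    if autoreset:
        # The previous transition ended an episode, so this step resets instead.
        observation, info = env.reset()
        reward, terminated, truncated = 0.0, False, False
    else:
        observation, reward, terminated, truncated, info = env.step(
            env.action_space.sample()
        )
    # Mirrors the changed line: recompute autoreset from the transition above.
    autoreset = terminated or truncated

env.close()
```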
5 changes: 3 additions & 2 deletions gymnasium/wrappers/common.py
Expand Up @@ -468,14 +468,14 @@ class RecordEpisodeStatistics(
def __init__(
self,
env: gym.Env[ObsType, ActType],
- buffer_length: int | None = 100,
+ buffer_length: int = 100,
stats_key: str = "episode",
):
"""This wrapper will keep track of cumulative rewards and episode lengths.
Args:
env (Env): The environment to apply the wrapper
- buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
+ buffer_length: The size of the buffers :attr:`return_queue`, :attr:`length_queue` and :attr:`time_queue`
stats_key: The info key for the episode statistics
"""
gym.utils.RecordConstructorArgs.__init__(self)
@@ -518,6 +518,7 @@ def step(
self.length_queue.append(self.episode_lengths)

self.episode_count += 1
+ self.episode_start_time = time.perf_counter()

return obs, reward, terminated, truncated, info

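For context, a minimal sketch of how the updated single-environment `RecordEpisodeStatistics` is typically used; the environment id, seed and step count are placeholders. The `self.episode_start_time` reset added above keeps the reported `t` entry scoped to the episode that just finished.

```python
import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics

# buffer_length bounds the rolling return/length/time queues kept by the wrapper.
env = RecordEpisodeStatistics(gym.make("CartPole-v1"), buffer_length=100)

obs, info = env.reset(seed=0)
for _ in range(500):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        # On episode end the wrapper reports cumulative reward ("r"),
        # episode length ("l") and elapsed wall-clock time in seconds ("t").
        print(info["episode"])
        obs, info = env.reset()

env.close()
```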
30 changes: 15 additions & 15 deletions gymnasium/wrappers/vector/common.py
@@ -62,21 +62,25 @@ class RecordEpisodeStatistics(VectorWrapper):
None, None], dtype=object)}
"""

- def __init__(self, env: VectorEnv, deque_size: int = 100):
+ def __init__(self, env: VectorEnv, deque_size: int = 100,
+     stats_key: str = "episode",
+ ):
"""This wrapper will keep track of cumulative rewards and episode lengths.
Args:
env (Env): The environment to apply the wrapper
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
"""
super().__init__(env)
+ self._stats_key = stats_key

self.episode_count = 0

self.episode_start_times: np.ndarray = np.zeros(())
self.episode_returns: np.ndarray = np.zeros(())
self.episode_lengths: np.ndarray = np.zeros(())

+ self.time_queue = deque(maxlen=deque_size)
self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size)

@@ -88,11 +92,9 @@ def reset(
"""Resets the environment using kwargs and resets the episode returns and lengths."""
obs, info = super().reset(seed=seed, options=options)

- self.episode_start_times = np.full(
-     self.num_envs, time.perf_counter(), dtype=np.float32
- )
- self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
- self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
+ self.episode_start_times = np.full(self.num_envs, time.perf_counter())
+ self.episode_returns = np.zeros(self.num_envs)
+ self.episode_lengths = np.zeros(self.num_envs)

return obs, info

@@ -119,25 +121,23 @@ def step(
num_dones = np.sum(dones)

if num_dones:
if "episode" in infos or "_episode" in infos:
if self._stats_key in infos or f'_{self._stats_key}' in infos:
raise ValueError(
"Attempted to add episode stats when they already exist"
f"Attempted to add episode stats when they already exist, info keys: {list(infos.keys())}"
)
else:
infos["episode"] = {
episode_time_length = np.round(time.perf_counter() - self.episode_start_times, 6)
infos[self._stats_key] = {
"r": np.where(dones, self.episode_returns, 0.0),
"l": np.where(dones, self.episode_lengths, 0),
"t": np.where(
dones,
np.round(time.perf_counter() - self.episode_start_times, 6),
0.0,
),
"t": np.where(dones, episode_time_length, 0.0),
}
infos["_episode"] = dones
infos[f'_{self._stats_key}'] = dones

self.episode_count += num_dones

for i in np.where(dones):
+ self.time_queue.extend(episode_time_length[i])
self.return_queue.extend(self.episode_returns[i])
self.length_queue.extend(self.episode_lengths[i])

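Similarly, a sketch of driving the vector wrapper after this change; `gym.make_vec`, the environment id, seed and step count are placeholders. The new `stats_key` argument only renames the `episode`/`_episode` info keys written here, and the added `time_queue` mirrors the existing return and length queues.

```python
import gymnasium as gym
from gymnasium.wrappers.vector import RecordEpisodeStatistics

envs = gym.make_vec("CartPole-v1", num_envs=3)
envs = RecordEpisodeStatistics(envs, deque_size=100, stats_key="episode")

obs, info = envs.reset(seed=0)
for _ in range(200):
    obs, rewards, terminations, truncations, info = envs.step(envs.action_space.sample())
    if "episode" in info:
        # "_episode" masks the sub-environments that finished on this step;
        # "r"/"l"/"t" hold per-env returns, lengths and elapsed seconds.
        print(info["episode"]["r"][info["_episode"]])

# Rolling statistics over the most recently finished episodes (up to deque_size).
print(len(envs.time_queue), len(envs.return_queue), len(envs.length_queue))
envs.close()
```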
8 changes: 4 additions & 4 deletions gymnasium/wrappers/vector/stateful_observation.py
@@ -34,9 +34,9 @@ class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructor
>>> for _ in range(100):
... obs, *_ = envs.step(envs.action_space.sample())
>>> np.mean(obs)
- -0.017698428
+ 0.024251968
>>> np.std(obs)
- 0.62041104
+ 0.62259156
>>> envs.close()
Example with the normalize reward wrapper:
@@ -48,9 +48,9 @@ class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructor
>>> for _ in range(100):
... obs, *_ = envs.step(envs.action_space.sample())
>>> np.mean(obs)
- -0.28381696
+ -0.2359734
>>> np.std(obs)
- 1.21742
+ 1.1938739
>>> envs.close()
"""

2 changes: 1 addition & 1 deletion tests/wrappers/vector/test_vector_wrappers.py
@@ -56,7 +56,7 @@ def custom_environments():
"RescaleAction",
{"min_action": 1, "max_action": 2},
),
("CartPole-v1", "ClipReward", {"min_reward": 0.25, "max_reward": 0.75}),
("CartPole-v1", "ClipReward", {"min_reward": -0.25, "max_reward": 0.75}),
),
)
def test_vector_wrapper_equivalence(
