Skip to content

Commit

Permalink
Improve the normalize observation testing and determinism (#784)
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts authored Nov 22, 2023
1 parent e8858f0 commit b312a03
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 25 deletions.
31 changes: 31 additions & 0 deletions tests/wrappers/test_normalize_observation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,40 @@
"""Test suite for NormalizeObservation wrapper."""
import numpy as np

from gymnasium import spaces, wrappers
from gymnasium.wrappers import NormalizeObservation
from tests.testing_env import GenericTestEnv


def test_normalization(convergence_steps: int = 1000, testing_steps: int = 100):
env = GenericTestEnv(
observation_space=spaces.Box(
low=np.array([0, -10, -5], dtype=np.float32),
high=np.array([10, -5, 10], dtype=np.float32),
)
)
env = wrappers.NormalizeObservation(env)

env.reset(seed=123)
env.observation_space.seed(123)
env.action_space.seed(123)
for _ in range(convergence_steps):
env.step(env.action_space.sample())

observations = []
for _ in range(testing_steps):
obs, *_ = env.step(env.action_space.sample())
observations.append(obs)
observations = np.array(observations) # (100, 3)

mean_obs = np.mean(observations, axis=0)
var_obs = np.var(observations, axis=0)
assert mean_obs.shape == (3,) and var_obs.shape == (3,)

assert np.allclose(mean_obs, np.zeros(3), atol=0.15)
assert np.allclose(var_obs, np.ones(3), atol=0.15)


def test_update_running_mean_property():
"""Tests that the property `_update_running_mean` freezes/continues the running statistics updating."""
env = GenericTestEnv()
Expand Down
80 changes: 55 additions & 25 deletions tests/wrappers/vector/test_normalize_observation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
"""Test suite for vector NormalizeObservation wrapper.."""
"""Test suite for vector NormalizeObservation wrapper."""
import numpy as np

from gymnasium import spaces, wrappers
from gymnasium.vector import SyncVectorEnv
from gymnasium.vector.utils import create_empty_array
from tests.testing_env import GenericTestEnv


def thunk():
def create_env():
return GenericTestEnv(
observation_space=spaces.Box(
low=np.array([0, -10, -5], dtype=np.float32),
Expand All @@ -16,55 +15,86 @@ def thunk():
)


def test_against_wrapper(
n_envs=3,
n_steps=250,
def test_normalization(
n_envs: int = 2, convergence_steps: int = 250, testing_steps: int = 100
):
vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
vec_env = wrappers.vector.NormalizeObservation(vec_env)

vec_env.reset(seed=123)
vec_env.observation_space.seed(123)
vec_env.action_space.seed(123)
for _ in range(convergence_steps):
vec_env.step(vec_env.action_space.sample())

observations = []
for _ in range(testing_steps):
obs, *_ = vec_env.step(vec_env.action_space.sample())
observations.append(obs)
observations = np.array(observations) # (100, 2, 3)

mean_obs = np.mean(observations, axis=(0, 1))
var_obs = np.var(observations, axis=(0, 1))
assert mean_obs.shape == (3,) and var_obs.shape == (3,)

assert np.allclose(mean_obs, np.zeros(3), atol=0.15)
assert np.allclose(var_obs, np.ones(3), atol=0.2)


def test_wrapper_equivalence(
n_envs: int = 3,
n_steps: int = 250,
mean_rtol=np.array([0.1, 0.4, 0.25]),
var_rtol=np.array([0.15, 0.15, 0.18]),
):
vec_env = SyncVectorEnv([thunk for _ in range(n_envs)])
vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
vec_env = wrappers.vector.NormalizeObservation(vec_env)

vec_env.reset(seed=123)
vec_env.observation_space.seed(123)
vec_env.action_space.seed(123)
for _ in range(n_steps):
vec_env.step(vec_env.action_space.sample())

env = wrappers.Autoreset(thunk())
env = wrappers.Autoreset(create_env())
env = wrappers.NormalizeObservation(env)
env.reset(seed=123)
env.action_space.seed(123)
for _ in range(n_envs * n_steps):
for _ in range(n_steps // n_envs):
env.step(env.action_space.sample())

assert np.allclose(env.obs_rms.mean, vec_env.obs_rms.mean, rtol=mean_rtol)
assert np.allclose(env.obs_rms.var, vec_env.obs_rms.var, rtol=var_rtol)


def test_update_running_mean():
env = SyncVectorEnv([thunk for _ in range(2)])
env = SyncVectorEnv([create_env for _ in range(2)])
env = wrappers.vector.NormalizeObservation(env)

# Default value is True
assert env.update_running_mean

obs, _ = env.reset()
env.reset()
for _ in range(100):
env.step(env.action_space.sample())

# Disable
# Disable updating the running mean
env.update_running_mean = False
rms_mean = np.copy(env.obs_rms.mean)
rms_var = np.copy(env.obs_rms.var)
copied_rms_mean = np.copy(env.obs_rms.mean)
copied_rms_var = np.copy(env.obs_rms.var)

val_step = 25
obs_buffer = create_empty_array(env.observation_space, val_step)
env.action_space.seed(123)
for i in range(val_step):
obs, _, _, _, _ = env.step(env.action_space.sample())
obs_buffer[i] = obs

assert np.all(rms_mean == env.obs_rms.mean)
assert np.all(rms_var == env.obs_rms.var)
assert np.allclose(np.mean(obs_buffer, axis=(0, 1)), 0, atol=0.5)
assert np.allclose(np.var(obs_buffer, axis=(0, 1)), 1, atol=0.5)
# Continue stepping through the environment and check that the running mean is not effected
for i in range(10):
env.step(env.action_space.sample())

assert np.all(copied_rms_mean == env.obs_rms.mean)
assert np.all(copied_rms_var == env.obs_rms.var)

# Re-enable updating the running mean
env.update_running_mean = True

for i in range(10):
env.step(env.action_space.sample())

assert np.any(copied_rms_mean != env.obs_rms.mean)
assert np.any(copied_rms_var != env.obs_rms.var)

0 comments on commit b312a03

Please sign in to comment.