diff --git a/tests/wrappers/test_normalize_observation.py b/tests/wrappers/test_normalize_observation.py index 1fccc3b97..0fff7a995 100644 --- a/tests/wrappers/test_normalize_observation.py +++ b/tests/wrappers/test_normalize_observation.py @@ -1,9 +1,40 @@ """Test suite for NormalizeObservation wrapper.""" +import numpy as np +from gymnasium import spaces, wrappers from gymnasium.wrappers import NormalizeObservation from tests.testing_env import GenericTestEnv +def test_normalization(convergence_steps: int = 1000, testing_steps: int = 100): + env = GenericTestEnv( + observation_space=spaces.Box( + low=np.array([0, -10, -5], dtype=np.float32), + high=np.array([10, -5, 10], dtype=np.float32), + ) + ) + env = wrappers.NormalizeObservation(env) + + env.reset(seed=123) + env.observation_space.seed(123) + env.action_space.seed(123) + for _ in range(convergence_steps): + env.step(env.action_space.sample()) + + observations = [] + for _ in range(testing_steps): + obs, *_ = env.step(env.action_space.sample()) + observations.append(obs) + observations = np.array(observations) # (100, 3) + + mean_obs = np.mean(observations, axis=0) + var_obs = np.var(observations, axis=0) + assert mean_obs.shape == (3,) and var_obs.shape == (3,) + + assert np.allclose(mean_obs, np.zeros(3), atol=0.15) + assert np.allclose(var_obs, np.ones(3), atol=0.15) + + def test_update_running_mean_property(): """Tests that the property `_update_running_mean` freezes/continues the running statistics updating.""" env = GenericTestEnv() diff --git a/tests/wrappers/vector/test_normalize_observation.py b/tests/wrappers/vector/test_normalize_observation.py index 773ca4037..3dab73b00 100644 --- a/tests/wrappers/vector/test_normalize_observation.py +++ b/tests/wrappers/vector/test_normalize_observation.py @@ -1,13 +1,12 @@ -"""Test suite for vector NormalizeObservation wrapper..""" +"""Test suite for vector NormalizeObservation wrapper.""" import numpy as np from gymnasium import spaces, wrappers from gymnasium.vector import SyncVectorEnv -from gymnasium.vector.utils import create_empty_array from tests.testing_env import GenericTestEnv -def thunk(): +def create_env(): return GenericTestEnv( observation_space=spaces.Box( low=np.array([0, -10, -5], dtype=np.float32), @@ -16,25 +15,52 @@ def thunk(): ) -def test_against_wrapper( - n_envs=3, - n_steps=250, +def test_normalization( + n_envs: int = 2, convergence_steps: int = 250, testing_steps: int = 100 +): + vec_env = SyncVectorEnv([create_env for _ in range(n_envs)]) + vec_env = wrappers.vector.NormalizeObservation(vec_env) + + vec_env.reset(seed=123) + vec_env.observation_space.seed(123) + vec_env.action_space.seed(123) + for _ in range(convergence_steps): + vec_env.step(vec_env.action_space.sample()) + + observations = [] + for _ in range(testing_steps): + obs, *_ = vec_env.step(vec_env.action_space.sample()) + observations.append(obs) + observations = np.array(observations) # (100, 2, 3) + + mean_obs = np.mean(observations, axis=(0, 1)) + var_obs = np.var(observations, axis=(0, 1)) + assert mean_obs.shape == (3,) and var_obs.shape == (3,) + + assert np.allclose(mean_obs, np.zeros(3), atol=0.15) + assert np.allclose(var_obs, np.ones(3), atol=0.2) + + +def test_wrapper_equivalence( + n_envs: int = 3, + n_steps: int = 250, mean_rtol=np.array([0.1, 0.4, 0.25]), var_rtol=np.array([0.15, 0.15, 0.18]), ): - vec_env = SyncVectorEnv([thunk for _ in range(n_envs)]) + vec_env = SyncVectorEnv([create_env for _ in range(n_envs)]) vec_env = wrappers.vector.NormalizeObservation(vec_env) vec_env.reset(seed=123) + vec_env.observation_space.seed(123) vec_env.action_space.seed(123) for _ in range(n_steps): vec_env.step(vec_env.action_space.sample()) - env = wrappers.Autoreset(thunk()) + env = wrappers.Autoreset(create_env()) env = wrappers.NormalizeObservation(env) env.reset(seed=123) env.action_space.seed(123) - for _ in range(n_envs * n_steps): + for _ in range(n_steps // n_envs): env.step(env.action_space.sample()) assert np.allclose(env.obs_rms.mean, vec_env.obs_rms.mean, rtol=mean_rtol) @@ -42,29 +68,33 @@ def test_against_wrapper( def test_update_running_mean(): - env = SyncVectorEnv([thunk for _ in range(2)]) + env = SyncVectorEnv([create_env for _ in range(2)]) env = wrappers.vector.NormalizeObservation(env) # Default value is True assert env.update_running_mean - obs, _ = env.reset() + env.reset() for _ in range(100): env.step(env.action_space.sample()) - # Disable + # Disable updating the running mean env.update_running_mean = False - rms_mean = np.copy(env.obs_rms.mean) - rms_var = np.copy(env.obs_rms.var) + copied_rms_mean = np.copy(env.obs_rms.mean) + copied_rms_var = np.copy(env.obs_rms.var) - val_step = 25 - obs_buffer = create_empty_array(env.observation_space, val_step) - env.action_space.seed(123) - for i in range(val_step): - obs, _, _, _, _ = env.step(env.action_space.sample()) - obs_buffer[i] = obs - - assert np.all(rms_mean == env.obs_rms.mean) - assert np.all(rms_var == env.obs_rms.var) - assert np.allclose(np.mean(obs_buffer, axis=(0, 1)), 0, atol=0.5) - assert np.allclose(np.var(obs_buffer, axis=(0, 1)), 1, atol=0.5) + # Continue stepping through the environment and check that the running mean is not effected + for i in range(10): + env.step(env.action_space.sample()) + + assert np.all(copied_rms_mean == env.obs_rms.mean) + assert np.all(copied_rms_var == env.obs_rms.var) + + # Re-enable updating the running mean + env.update_running_mean = True + + for i in range(10): + env.step(env.action_space.sample()) + + assert np.any(copied_rms_mean != env.obs_rms.mean) + assert np.any(copied_rms_var != env.obs_rms.var)