diff --git a/gymnasium/experimental/wrappers/lambda_observations.py b/gymnasium/experimental/wrappers/lambda_observations.py index 1394b1d4d..2db98740f 100644 --- a/gymnasium/experimental/wrappers/lambda_observations.py +++ b/gymnasium/experimental/wrappers/lambda_observations.py @@ -503,6 +503,10 @@ def __init__( class NormalizeObservationV0(ObservationWrapper): """This wrapper will normalize observations s.t. each coordinate is centered with unit variance. + The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation + statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.observation()` is called. + If `False`, the calculated statistics are used but not updated anymore; this may be used during evaluation. + Note: The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was newly instantiated or the policy was changed recently. @@ -518,10 +522,22 @@ def __init__(self, env: gym.Env, epsilon: float = 1e-8): super().__init__(env) self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) self.epsilon = epsilon + self._update_running_mean = True + + @property + def update_running_mean(self) -> bool: + """Property to freeze/continue the running mean calculation of the observation statistics.""" + return self._update_running_mean + + @update_running_mean.setter + def update_running_mean(self, setting: bool): + """Sets the property to freeze/continue the running mean calculation of the observation statistics.""" + self._update_running_mean = setting def observation(self, observation: ObsType) -> WrapperObsType: """Normalises the observation using the running mean and variance of the observations.""" - self.obs_rms.update(observation) + if self._update_running_mean: + self.obs_rms.update(observation) return (observation - self.obs_rms.mean) / np.sqrt( self.obs_rms.var + self.epsilon ) diff --git a/gymnasium/experimental/wrappers/lambda_reward.py b/gymnasium/experimental/wrappers/lambda_reward.py index db3b3d690..51917c371 100644 --- a/gymnasium/experimental/wrappers/lambda_reward.py +++ b/gymnasium/experimental/wrappers/lambda_reward.py @@ -125,6 +125,17 @@ def __init__( self.discounted_reward: np.array = np.array([0.0]) self.gamma = gamma self.epsilon = epsilon + self._update_running_mean = True + + @property + def update_running_mean(self) -> bool: + """Property to freeze/continue the running mean calculation of the reward statistics.""" + return self._update_running_mean + + @update_running_mean.setter + def update_running_mean(self, setting: bool): + """Sets the property to freeze/continue the running mean calculation of the reward statistics.""" + self._update_running_mean = setting def step( self, action: WrapperActType @@ -138,5 +149,6 @@ def step( def normalize(self, reward): """Normalizes the rewards with the running mean rewards and their variance.""" - self.rewards_running_means.update(self.discounted_reward) + if self._update_running_mean: + self.rewards_running_means.update(self.discounted_reward) return reward / np.sqrt(self.rewards_running_means.var + self.epsilon) diff --git a/tests/experimental/wrappers/test_normalize_observation.py b/tests/experimental/wrappers/test_normalize_observation.py index 6e4b8c7c4..22889dbe7 100644 --- a/tests/experimental/wrappers/test_normalize_observation.py +++ b/tests/experimental/wrappers/test_normalize_observation.py @@ -1 +1,32 @@ """Test suite for NormalizeObservationV0.""" +from gymnasium.experimental.wrappers import NormalizeObservationV0 +from tests.testing_env import GenericTestEnv + + +def test_running_mean_normalize_observation_wrapper(): + """Tests that the property `_update_running_mean` freezes/continues the running statistics updating.""" + env = GenericTestEnv() + wrapped_env = NormalizeObservationV0(env) + + # Default value is True + assert wrapped_env.update_running_mean + + wrapped_env.reset() + rms_var_init = wrapped_env.obs_rms.var + rms_mean_init = wrapped_env.obs_rms.mean + + # Statistics are updated when env.step() + wrapped_env.step(None) + rms_var_updated = wrapped_env.obs_rms.var + rms_mean_updated = wrapped_env.obs_rms.mean + assert rms_var_init != rms_var_updated + assert rms_mean_init != rms_mean_updated + + # Assure property is set + wrapped_env.update_running_mean = False + assert not wrapped_env.update_running_mean + + # Statistics are frozen + wrapped_env.step(None) + assert rms_var_updated == wrapped_env.obs_rms.var + assert rms_mean_updated == wrapped_env.obs_rms.mean diff --git a/tests/experimental/wrappers/test_normalize_reward.py b/tests/experimental/wrappers/test_normalize_reward.py index e2c643b07..7045b8650 100644 --- a/tests/experimental/wrappers/test_normalize_reward.py +++ b/tests/experimental/wrappers/test_normalize_reward.py @@ -15,6 +15,35 @@ def step_func(self, action: ActType): return GenericTestEnv(step_func=step_func) +def test_running_mean_normalize_reward_wrapper(): + """Tests that the property `_update_running_mean` freezes/continues the running statistics updating.""" + env = _make_reward_env() + wrapped_env = NormalizeRewardV0(env) + + # Default value is True + assert wrapped_env.update_running_mean + + wrapped_env.reset() + rms_var_init = wrapped_env.rewards_running_means.var + rms_mean_init = wrapped_env.rewards_running_means.mean + + # Statistics are updated when env.step() + wrapped_env.step(None) + rms_var_updated = wrapped_env.rewards_running_means.var + rms_mean_updated = wrapped_env.rewards_running_means.mean + assert rms_var_init != rms_var_updated + assert rms_mean_init != rms_mean_updated + + # Assure property is set + wrapped_env.update_running_mean = False + assert not wrapped_env.update_running_mean + + # Statistics are frozen + wrapped_env.step(None) + assert rms_var_updated == wrapped_env.rewards_running_means.var + assert rms_mean_updated == wrapped_env.rewards_running_means.mean + + def test_normalize_reward_wrapper(): """Tests that the NormalizeReward does not throw an error.""" # TODO: Functional correctness should be tested