Skip to content

Commit

Permalink
Adding a self._update_running_mean property to block updating the sta…
Browse files Browse the repository at this point in the history
…tistics of NormalizeX Wrappers (openai#268)

Co-authored-by: raphajaner <[email protected]>
  • Loading branch information
raphajaner and raphajaner authored Jan 20, 2023
1 parent 57819c4 commit cd8dc81
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 2 deletions.
18 changes: 17 additions & 1 deletion gymnasium/experimental/wrappers/lambda_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,10 @@ def __init__(
class NormalizeObservationV0(ObservationWrapper):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation
statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.observation()` is called.
If `False`, the calculated statistics are used but not updated anymore; this may be used during evaluation.
Note:
The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was
newly instantiated or the policy was changed recently.
Expand All @@ -518,10 +522,22 @@ def __init__(self, env: gym.Env, epsilon: float = 1e-8):
super().__init__(env)
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
self.epsilon = epsilon
self._update_running_mean = True

@property
def update_running_mean(self) -> bool:
"""Property to freeze/continue the running mean calculation of the observation statistics."""
return self._update_running_mean

@update_running_mean.setter
def update_running_mean(self, setting: bool):
"""Sets the property to freeze/continue the running mean calculation of the observation statistics."""
self._update_running_mean = setting

def observation(self, observation: ObsType) -> WrapperObsType:
"""Normalises the observation using the running mean and variance of the observations."""
self.obs_rms.update(observation)
if self._update_running_mean:
self.obs_rms.update(observation)
return (observation - self.obs_rms.mean) / np.sqrt(
self.obs_rms.var + self.epsilon
)
14 changes: 13 additions & 1 deletion gymnasium/experimental/wrappers/lambda_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ def __init__(
self.discounted_reward: np.array = np.array([0.0])
self.gamma = gamma
self.epsilon = epsilon
self._update_running_mean = True

@property
def update_running_mean(self) -> bool:
"""Property to freeze/continue the running mean calculation of the reward statistics."""
return self._update_running_mean

@update_running_mean.setter
def update_running_mean(self, setting: bool):
"""Sets the property to freeze/continue the running mean calculation of the reward statistics."""
self._update_running_mean = setting

def step(
self, action: WrapperActType
Expand All @@ -138,5 +149,6 @@ def step(

def normalize(self, reward):
"""Normalizes the rewards with the running mean rewards and their variance."""
self.rewards_running_means.update(self.discounted_reward)
if self._update_running_mean:
self.rewards_running_means.update(self.discounted_reward)
return reward / np.sqrt(self.rewards_running_means.var + self.epsilon)
31 changes: 31 additions & 0 deletions tests/experimental/wrappers/test_normalize_observation.py
Original file line number Diff line number Diff line change
@@ -1 +1,32 @@
"""Test suite for NormalizeObservationV0."""
from gymnasium.experimental.wrappers import NormalizeObservationV0
from tests.testing_env import GenericTestEnv


def test_running_mean_normalize_observation_wrapper():
"""Tests that the property `_update_running_mean` freezes/continues the running statistics updating."""
env = GenericTestEnv()
wrapped_env = NormalizeObservationV0(env)

# Default value is True
assert wrapped_env.update_running_mean

wrapped_env.reset()
rms_var_init = wrapped_env.obs_rms.var
rms_mean_init = wrapped_env.obs_rms.mean

# Statistics are updated when env.step()
wrapped_env.step(None)
rms_var_updated = wrapped_env.obs_rms.var
rms_mean_updated = wrapped_env.obs_rms.mean
assert rms_var_init != rms_var_updated
assert rms_mean_init != rms_mean_updated

# Assure property is set
wrapped_env.update_running_mean = False
assert not wrapped_env.update_running_mean

# Statistics are frozen
wrapped_env.step(None)
assert rms_var_updated == wrapped_env.obs_rms.var
assert rms_mean_updated == wrapped_env.obs_rms.mean
29 changes: 29 additions & 0 deletions tests/experimental/wrappers/test_normalize_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,35 @@ def step_func(self, action: ActType):
return GenericTestEnv(step_func=step_func)


def test_running_mean_normalize_reward_wrapper():
"""Tests that the property `_update_running_mean` freezes/continues the running statistics updating."""
env = _make_reward_env()
wrapped_env = NormalizeRewardV0(env)

# Default value is True
assert wrapped_env.update_running_mean

wrapped_env.reset()
rms_var_init = wrapped_env.rewards_running_means.var
rms_mean_init = wrapped_env.rewards_running_means.mean

# Statistics are updated when env.step()
wrapped_env.step(None)
rms_var_updated = wrapped_env.rewards_running_means.var
rms_mean_updated = wrapped_env.rewards_running_means.mean
assert rms_var_init != rms_var_updated
assert rms_mean_init != rms_mean_updated

# Assure property is set
wrapped_env.update_running_mean = False
assert not wrapped_env.update_running_mean

# Statistics are frozen
wrapped_env.step(None)
assert rms_var_updated == wrapped_env.rewards_running_means.var
assert rms_mean_updated == wrapped_env.rewards_running_means.mean


def test_normalize_reward_wrapper():
"""Tests that the NormalizeReward does not throw an error."""
# TODO: Functional correctness should be tested
Expand Down

0 comments on commit cd8dc81

Please sign in to comment.