forked from openai/retro
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding a self._update_running_mean property to block updating the sta…
…tistics of NormalizeX Wrappers (openai#268) Co-authored-by: raphajaner <[email protected]>
- Loading branch information
1 parent
57819c4
commit cd8dc81
Showing
4 changed files
with
90 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,32 @@ | ||
"""Test suite for NormalizeObservationV0.""" | ||
from gymnasium.experimental.wrappers import NormalizeObservationV0 | ||
from tests.testing_env import GenericTestEnv | ||
|
||
|
||
def test_running_mean_normalize_observation_wrapper(): | ||
"""Tests that the property `_update_running_mean` freezes/continues the running statistics updating.""" | ||
env = GenericTestEnv() | ||
wrapped_env = NormalizeObservationV0(env) | ||
|
||
# Default value is True | ||
assert wrapped_env.update_running_mean | ||
|
||
wrapped_env.reset() | ||
rms_var_init = wrapped_env.obs_rms.var | ||
rms_mean_init = wrapped_env.obs_rms.mean | ||
|
||
# Statistics are updated when env.step() | ||
wrapped_env.step(None) | ||
rms_var_updated = wrapped_env.obs_rms.var | ||
rms_mean_updated = wrapped_env.obs_rms.mean | ||
assert rms_var_init != rms_var_updated | ||
assert rms_mean_init != rms_mean_updated | ||
|
||
# Assure property is set | ||
wrapped_env.update_running_mean = False | ||
assert not wrapped_env.update_running_mean | ||
|
||
# Statistics are frozen | ||
wrapped_env.step(None) | ||
assert rms_var_updated == wrapped_env.obs_rms.var | ||
assert rms_mean_updated == wrapped_env.obs_rms.mean |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters