From 4b5abb694b52eed9b06a9c53fa6d0d0602467baf Mon Sep 17 00:00:00 2001 From: raphajaner Date: Fri, 20 Jan 2023 15:25:31 +0100 Subject: [PATCH] Fix experimental normalize reward wrapper (#277) Co-authored-by: raphajaner --- .../experimental/wrappers/lambda_reward.py | 6 ++++- .../wrappers/test_normalize_reward.py | 25 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/gymnasium/experimental/wrappers/lambda_reward.py b/gymnasium/experimental/wrappers/lambda_reward.py index d95ddec8b..db3b3d690 100644 --- a/gymnasium/experimental/wrappers/lambda_reward.py +++ b/gymnasium/experimental/wrappers/lambda_reward.py @@ -98,6 +98,10 @@ class NormalizeRewardV0(gym.Wrapper): The exponential moving average will have variance :math:`(1 - \gamma)^2`. + The property `_update_running_mean` allows to freeze/continue the running mean calculation of the reward + statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.normalize()` is called. + If False, the calculated statistics are used but not updated anymore; this may be used during evaluation. + Note: The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly instantiated or the policy was changed recently. @@ -118,7 +122,7 @@ def __init__( """ super().__init__(env) self.rewards_running_means = RunningMeanStd(shape=()) - self.discounted_reward: float = 0.0 + self.discounted_reward: np.array = np.array([0.0]) self.gamma = gamma self.epsilon = epsilon diff --git a/tests/experimental/wrappers/test_normalize_reward.py b/tests/experimental/wrappers/test_normalize_reward.py index 845cd9059..e2c643b07 100644 --- a/tests/experimental/wrappers/test_normalize_reward.py +++ b/tests/experimental/wrappers/test_normalize_reward.py @@ -1 +1,26 @@ """Test suite for NormalizeRewardV0.""" +import numpy as np + +from gymnasium.core import ActType +from gymnasium.experimental.wrappers import NormalizeRewardV0 +from tests.testing_env import GenericTestEnv + + +def _make_reward_env(): + """Function that returns a `GenericTestEnv` with reward=1.""" + + def step_func(self, action: ActType): + return self.observation_space.sample(), 1.0, False, False, {} + + return GenericTestEnv(step_func=step_func) + + +def test_normalize_reward_wrapper(): + """Tests that the NormalizeReward does not throw an error.""" + # TODO: Functional correctness should be tested + env = _make_reward_env() + wrapped_env = NormalizeRewardV0(env) + wrapped_env.reset() + _, reward, _, _, _ = wrapped_env.step(None) + assert np.ndim(reward) == 0 + env.close()