Skip to content

Commit

Permalink
Fix experimental normalize reward wrapper (openai#277)
Browse files Browse the repository at this point in the history
Co-authored-by: raphajaner <[email protected]>
  • Loading branch information
raphajaner and raphajaner authored Jan 20, 2023
1 parent b4caf9d commit 4b5abb6
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
6 changes: 5 additions & 1 deletion gymnasium/experimental/wrappers/lambda_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ class NormalizeRewardV0(gym.Wrapper):
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
The property `_update_running_mean` allows freezing/continuing the running mean calculation of the reward
statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.normalize()` is called.
If `False`, the calculated statistics are used but not updated anymore; this may be used during evaluation.
Note:
The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
instantiated or the policy was changed recently.
Expand All @@ -118,7 +122,7 @@ def __init__(
"""
super().__init__(env)
self.rewards_running_means = RunningMeanStd(shape=())
self.discounted_reward: float = 0.0
self.discounted_reward: np.array = np.array([0.0])
self.gamma = gamma
self.epsilon = epsilon

Expand Down
25 changes: 25 additions & 0 deletions tests/experimental/wrappers/test_normalize_reward.py
Original file line number Diff line number Diff line change
@@ -1 +1,26 @@
"""Test suite for NormalizeRewardV0."""
import numpy as np

from gymnasium.core import ActType
from gymnasium.experimental.wrappers import NormalizeRewardV0
from tests.testing_env import GenericTestEnv


def _make_reward_env():
    """Build a `GenericTestEnv` that emits a constant reward of 1 on every step."""

    def _constant_reward_step(self, action: ActType):
        # Observation content is irrelevant for reward-normalization tests;
        # only the fixed reward of 1.0 matters.
        observation = self.observation_space.sample()
        return observation, 1.0, False, False, {}

    return GenericTestEnv(step_func=_constant_reward_step)


def test_normalize_reward_wrapper():
    """Tests that the NormalizeReward wrapper runs and returns a scalar reward."""
    # TODO: Functional correctness (running mean/std of the normalization) should be tested
    env = _make_reward_env()
    wrapped_env = NormalizeRewardV0(env)
    wrapped_env.reset()
    _, reward, _, _, _ = wrapped_env.step(None)
    # Regression check: the normalized reward must stay a scalar, not become
    # a 1-element array (the bug this wrapper fix addressed).
    assert np.ndim(reward) == 0
    # Close via the wrapper so cleanup propagates through it to the inner env.
    wrapped_env.close()

0 comments on commit 4b5abb6

Please sign in to comment.