Skip to content

Commit

Permalink
Add NormalizeObservation wrapper observation space (#978)
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts authored Mar 22, 2024
1 parent e9fa737 commit 144feb8
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
13 changes: 11 additions & 2 deletions gymnasium/wrappers/stateful_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ class NormalizeObservation(
Change logs:
* v0.21.0 - Initially added
* v1.0.0 - Add `update_running_mean` attribute to allow disabling updates of the running mean / variance, which is particularly useful at evaluation time.
Casts all observations to `np.float32` and sets the observation space with low/high of `-np.inf` and `np.inf` and dtype as `np.float32`.
"""

def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):
Expand All @@ -499,6 +500,14 @@ def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):
gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
gym.ObservationWrapper.__init__(self, env)

assert env.observation_space.shape is not None
self.observation_space = gym.spaces.Box(
low=-np.inf,
high=np.inf,
shape=env.observation_space.shape,
dtype=np.float32,
)

self.obs_rms = RunningMeanStd(
shape=self.observation_space.shape, dtype=self.observation_space.dtype
)
Expand All @@ -519,8 +528,8 @@ def observation(self, observation: ObsType) -> WrapperObsType:
"""Normalises the observation using the running mean and variance of the observations."""
if self._update_running_mean:
self.obs_rms.update(np.array([observation]))
return (observation - self.obs_rms.mean) / np.sqrt(
self.obs_rms.var + self.epsilon
return np.float32(
(observation - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon)
)


Expand Down
12 changes: 12 additions & 0 deletions tests/wrappers/test_normalize_observation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test suite for NormalizeObservation wrapper."""
import numpy as np

import gymnasium as gym
from gymnasium import spaces, wrappers
from gymnasium.wrappers import NormalizeObservation
from tests.testing_env import GenericTestEnv
Expand Down Expand Up @@ -62,3 +63,14 @@ def test_update_running_mean_property():
wrapped_env.step(None)
assert rms_var_updated == wrapped_env.obs_rms.var
assert rms_mean_updated == wrapped_env.obs_rms.mean


def test_normalize_obs_with_vector():
    """Check that `NormalizeObservation` composes with `SyncVectorEnv`.

    The wrapper declares a `np.float32` observation space, and the vector
    environment batches observations according to that space, so the batched
    reset observation is expected to be `np.float32` with the batched shape.
    """

    def thunk():
        """Build one grayscale CarRacing environment with normalised observations."""
        env = gym.make("CarRacing-v2")
        env = gym.wrappers.GrayscaleObservation(env)
        env = gym.wrappers.NormalizeObservation(env)
        return env

    envs = gym.vector.SyncVectorEnv([thunk for _ in range(4)])
    try:
        obs, _ = envs.reset()

        # Without explicit checks the test only proves that `reset` does not
        # raise; pin the dtype and shape that the wrapper's space advertises.
        assert obs.dtype == np.float32
        assert obs.shape == envs.observation_space.shape
    finally:
        # Always release the sub-environments, even when an assertion fails.
        envs.close()

0 comments on commit 144feb8

Please sign in to comment.