Skip to content

Commit

Permalink
Update the wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts committed Nov 22, 2023
1 parent 8a9159b commit a85a6e0
Show file tree
Hide file tree
Showing 11 changed files with 166 additions and 192 deletions.
45 changes: 13 additions & 32 deletions gymnasium/vector/vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,23 +444,23 @@ def reset(
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
obs, info = self.env.reset(seed=seed, options=options)
return self.vector_observation(obs), info
observations, infos = self.env.reset(seed=seed, options=options)
return self.observation(observations), infos

def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
observation, reward, termination, truncation, info = self.env.step(actions)
observations, rewards, terminations, truncations, infos = self.env.step(actions)
return (
self.vector_observation(observation),
reward,
termination,
truncation,
self.update_final_obs(info),
self.observation(observations),
rewards,
terminations,
truncations,
infos,
)

def vector_observation(self, observation: ObsType) -> ObsType:
def observation(self, observation: ObsType) -> ObsType:
"""Defines the vector observation transformation.
Args:
Expand All @@ -471,25 +471,6 @@ def vector_observation(self, observation: ObsType) -> ObsType:
"""
raise NotImplementedError

def single_observation(self, observation: ObsType) -> ObsType:
"""Defines the single observation transformation.
Args:
observation: A single observation from the environment
Returns:
The transformed observation
"""
raise NotImplementedError

def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
"""Updates the `final_obs` in the info using `single_observation`."""
if "final_observation" in info:
for i, obs in enumerate(info["final_observation"]):
if obs is not None:
info["final_observation"][i] = self.single_observation(obs)
return info


class VectorActionWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the actions.
Expand Down Expand Up @@ -525,14 +506,14 @@ def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Steps through the environment returning a reward modified by :meth:`reward`."""
observation, reward, termination, truncation, info = self.env.step(actions)
return observation, self.rewards(reward), termination, truncation, info
observations, rewards, terminations, truncations, infos = self.env.step(actions)
return observations, self.rewards(rewards), terminations, truncations, infos

def rewards(self, reward: ArrayType) -> ArrayType:
def rewards(self, rewards: ArrayType) -> ArrayType:
"""Transform the reward before returning it.
Args:
reward (array): the reward to transform
rewards (array): the rewards to transform
Returns:
array: the transformed reward
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/wrappers/stateful_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ class NormalizeObservation(
Change logs:
* v0.21.0 - Initially added
* v1.0.0 - Add `update_running_mean` attribute to allow disabling of updating the running mean / standard
* v1.0.0 - Add `update_running_mean` attribute to allow disabling of updating the running mean / standard deviation, particularly useful for evaluation time.
"""

def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/wrappers/vector/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def step(

assert isinstance(
infos, dict
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
), f"`vector.RecordEpisodeStatistics` requires `info` type to be `dict`, its actual type is {type(infos)}. This may be due to usage of other wrappers in the wrong order."

self.episode_returns += rewards
self.episode_lengths += 1
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/wrappers/vector/dict_info_to_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def step(
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]:
"""Steps through the environment, convert dict info to list."""
observation, reward, terminated, truncated, infos = self.env.step(actions)
assert isinstance(infos, dict)
list_info = self._convert_info_to_list(infos)

return observation, reward, terminated, truncated, list_info
Expand All @@ -92,6 +93,7 @@ def reset(
) -> tuple[ObsType, list[dict[str, Any]]]:
"""Resets the environment using kwargs."""
obs, infos = self.env.reset(seed=seed, options=options)
assert isinstance(infos, dict)
list_info = self._convert_info_to_list(infos)

return obs, list_info
Expand Down
18 changes: 2 additions & 16 deletions gymnasium/wrappers/vector/stateful_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,29 +81,15 @@ def update_running_mean(self, setting: bool):
"""Sets the property to freeze/continue the running mean calculation of the observation statistics."""
self._update_running_mean = setting

def vector_observation(self, observation: ObsType) -> ObsType:
def observation(self, observations: ObsType) -> ObsType:
"""Defines the vector observation normalization function.
Args:
observation: A vector observation from the environment
observations: A vector observation from the environment
Returns:
the normalized observation
"""
return self._normalize_observations(observation)

def single_observation(self, observation: ObsType) -> ObsType:
"""Defines the single observation normalization function.
Args:
observation: A single observation from the environment
Returns:
The normalized observation
"""
return self._normalize_observations(observation[None])

def _normalize_observations(self, observations: ObsType) -> ObsType:
if self._update_running_mean:
self.obs_rms.update(observations)
return (observations - self.obs_rms.mean) / np.sqrt(
Expand Down
34 changes: 10 additions & 24 deletions gymnasium/wrappers/vector/vectorize_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,10 @@ class TransformObservation(VectorObservationWrapper):
>>> def scale_and_shift(obs):
... return (obs - 1.0) * 2.0
...
>>> def vector_scale_and_shift(obs):
... return (obs - 1.0) * 2.0
...
>>> import gymnasium as gym
>>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
>>> new_obs_space = Box(low=envs.observation_space.low, high=envs.observation_space.high)
>>> envs = TransformObservation(envs, single_func=scale_and_shift, vector_func=vector_scale_and_shift)
>>> envs = TransformObservation(envs, func=scale_and_shift, observation_space=new_obs_space)
>>> obs, info = envs.reset(seed=123)
>>> obs
array([[-1.9635296, -2.0892358, -2.055928 , -2.0631256],
Expand All @@ -55,33 +52,26 @@ class TransformObservation(VectorObservationWrapper):
def __init__(
self,
env: VectorEnv,
vector_func: Callable[[ObsType], Any],
single_func: Callable[[ObsType], Any],
func: Callable[[ObsType], Any],
observation_space: Space | None = None,
):
"""Constructor for the transform observation wrapper.
Args:
env: The vector environment to wrap
vector_func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
single_func: A function that will transform an individual observation, this function will be used for the final observation from the environment and is returned under ``info`` and not the normal observation.
func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
"""
super().__init__(env)

if observation_space is not None:
self.observation_space = observation_space

self.vector_func = vector_func
self.single_func = single_func
self.func = func

def vector_observation(self, observation: ObsType) -> ObsType:
def observation(self, observations: ObsType) -> ObsType:
"""Apply function to the vector observation."""
return self.vector_func(observation)

def single_observation(self, observation: ObsType) -> ObsType:
"""Apply function to the single observation."""
return self.single_func(observation)
return self.func(observations)


class VectorizeTransformObservation(VectorObservationWrapper):
Expand Down Expand Up @@ -158,33 +148,29 @@ def __init__(
self.same_out = self.observation_space == self.env.observation_space
self.out = create_empty_array(self.single_observation_space, self.num_envs)

def vector_observation(self, observation: ObsType) -> ObsType:
def observation(self, observations: ObsType) -> ObsType:
"""Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again."""
if self.same_out:
return concatenate(
self.single_observation_space,
tuple(
self.wrapper.func(obs)
for obs in iterate(self.observation_space, observation)
for obs in iterate(self.observation_space, observations)
),
observation,
observations,
)
else:
return deepcopy(
concatenate(
self.single_observation_space,
tuple(
self.wrapper.func(obs)
for obs in iterate(self.env.observation_space, observation)
for obs in iterate(self.env.observation_space, observations)
),
self.out,
)
)

def single_observation(self, observation: ObsType) -> ObsType:
"""Transforms a single observation using the wrapper transformation function."""
return self.wrapper.func(observation)


class FilterObservation(VectorizeTransformObservation):
"""Vector wrapper for filtering dict or tuple observation spaces.
Expand Down
2 changes: 1 addition & 1 deletion tests/vector/test_async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_create_async_vector_env(shared_memory):

@pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_async_vector_env(shared_memory):
"""Test the reset of an sync vector environment with or without shared memory."""
"""Test the reset of an async vector environment with or without shared memory."""
env_fns = [make_env("CartPole-v1", i) for i in range(8)]

env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
Expand Down
115 changes: 52 additions & 63 deletions tests/vector/test_vector_env_info.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
"""Test the vector environment information."""
from __future__ import annotations

from typing import Any, SupportsFloat

import numpy as np
import pytest

import gymnasium as gym
from gymnasium.spaces import Discrete
from gymnasium.core import ActType, ObsType
from gymnasium.spaces import Box, Discrete
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.vector import VectorEnv
from gymnasium.vector.sync_vector_env import SyncVectorEnv
from tests.vector.testing_utils import make_env
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv


def test_examples():
def test_vector_add_info():
env = VectorEnv()

# Test num-envs==1 then expand_dims(sub-env-info) == vector-infos
Expand Down Expand Up @@ -114,62 +117,48 @@ def test_examples():
assert data_equivalence(vector_infos, expected_vector_infos)


@pytest.mark.parametrize("vectorization_mode", ["async", "sync"])
def test_vector_env_info(
vectorization_mode: str,
env_id: str = "CartPole-v1",
num_envs: int = 3,
env_steps: int = 50,
seed: int = 123,
):
"""Test vector environment info for different vectorization modes."""
env = gym.make_vec(
env_id,
num_envs=num_envs,
vectorization_mode=vectorization_mode,
class ReturnInfoEnv(gym.Env):
def __init__(self, infos):
self.observation_space = Box(0, 1)
self.action_space = Box(0, 1)

self.infos = infos

def reset(
self,
*,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
return self.observation_space.sample(), self.infos[0]

def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
return self.observation_space.sample(), 0, True, False, self.infos[1]


@pytest.mark.parametrize("vectorizer", [AsyncVectorEnv, SyncVectorEnv])
def test_vectorizers(vectorizer):
vec_env = vectorizer(
[
lambda: ReturnInfoEnv([{"a": 1}, {"c": np.array([1, 2])}]),
lambda: ReturnInfoEnv([{"a": 2, "b": 3}, {"c": np.array([3, 4])}]),
]
)
env.reset(seed=seed)
for _ in range(env_steps):
env.action_space.seed(seed)
action = env.action_space.sample()
_, _, terminations, truncations, infos = env.step(action)
if any(terminations) or any(truncations):
assert len(infos["final_observation"]) == num_envs
assert len(infos["_final_observation"]) == num_envs

assert isinstance(infos["final_observation"], np.ndarray)
assert isinstance(infos["_final_observation"], np.ndarray)

for i, (terminated, truncated) in enumerate(zip(terminations, truncations)):
if terminated or truncated:
assert infos["_final_observation"][i]
else:
assert not infos["_final_observation"][i]
assert infos["final_observation"][i] is None


@pytest.mark.parametrize("concurrent_ends", [1, 2, 3])
def test_vector_env_info_concurrent_termination(
concurrent_ends: int,
env_id: str = "CartPole-v1",
num_envs: int = 3,
env_steps: int = 50,
seed: int = 123,
):
"""Test the vector environment information works with concurrent termination."""
# envs that need to terminate together will have the same action
actions = [0] * concurrent_ends + [1] * (num_envs - concurrent_ends)
envs = [make_env(env_id, seed) for _ in range(num_envs)]
envs = SyncVectorEnv(envs)

for _ in range(env_steps):
_, _, terminations, truncations, infos = envs.step(actions)
if any(terminations) or any(truncations):
for i, (terminated, truncated) in enumerate(zip(terminations, truncations)):
if i < concurrent_ends:
assert terminated or truncated
assert infos["_final_observation"][i]
else:
assert not infos["_final_observation"][i]
assert infos["final_observation"][i] is None
return

reset_expected_infos = {
"a": np.array([1, 2]),
"b": np.array([0, 3]),
"_a": np.array([True, True]),
"_b": np.array([False, True]),
}
step_expected_infos = {
"c": np.array([[1, 2], [3, 4]]),
"_c": np.array([True, True]),
}

_, reset_info = vec_env.reset()
assert data_equivalence(reset_info, reset_expected_infos)
_, _, _, _, step_info = vec_env.step(vec_env.action_space.sample())
assert data_equivalence(step_info, step_expected_infos)
Loading

0 comments on commit a85a6e0

Please sign in to comment.