Skip to content

Commit

Permalink
Add check doctest to CI and fixed existing errors (#274)
Browse files Browse the repository at this point in the history
  • Loading branch information
vcharraut authored Jan 20, 2023
1 parent 1551b89 commit b4caf9d
Show file tree
Hide file tree
Showing 45 changed files with 321 additions and 222 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ jobs:
--tag gymnasium-all-docker .
- name: Run tests
run: docker run gymnasium-all-docker pytest tests/*
- name: Run doctest
run: docker run gymnasium-all-docker pytest --doctest-modules gymnasium/

build-necessary:
runs-on:
Expand Down
33 changes: 29 additions & 4 deletions gymnasium/envs/phys2d/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,50 @@ class CartPoleFunctional(
>>> import jax
>>> import jax.numpy as jnp
>>> from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
>>> key = jax.random.PRNGKey(0)
>>> env = CartPole({"x_init": 0.5})
>>> env = CartPoleFunctional({"x_init": 0.5})
>>> state = env.initial(key)
>>> print(state)
>>> print(env.step(state, 0))
[ 0.46532142 -0.27484107 0.13302994 -0.20361817]
>>> print(env.transition(state, 0))
[ 0.4598246 -0.6357784 0.12895757 0.1278053 ]
>>> env.transform(jax.jit)
>>> state = env.initial(key)
>>> print(state)
>>> print(env.step(state, 0))
[ 0.46532142 -0.27484107 0.13302994 -0.20361817]
>>> print(env.transition(state, 0))
[ 0.4598246 -0.6357784 0.12895757 0.12780523]
>>> vkey = jax.random.split(key, 10)
>>> env.transform(jax.vmap)
>>> vstate = env.initial(vkey)
>>> print(vstate)
>>> print(env.step(vstate, jnp.array([0 for _ in range(10)])))
[[ 0.25117755 -0.03159595 0.09428263 0.12404168]
[ 0.231457 0.41420317 -0.13484478 0.29151905]
[-0.11706758 -0.37130308 0.13587534 0.33141208]
[-0.4613737 0.36557996 0.3950702 0.3639989 ]
[-0.14707637 -0.34273267 -0.32374108 -0.48110402]
[-0.45774353 0.3633288 -0.3157575 -0.03586268]
[ 0.37344885 -0.279778 -0.33894253 0.07415426]
[-0.20234215 0.39775252 -0.2556088 0.32877135]
[-0.2572986 -0.29943776 -0.45600426 -0.35740316]
[ 0.05436695 0.35021234 -0.36484408 0.2805779 ]]
>>> print(env.transition(vstate, jnp.array([0 for _ in range(10)])))
[[ 0.25054562 -0.38763174 0.09676346 0.4448946 ]
[ 0.23974106 0.09849604 -0.1290144 0.5390002 ]
[-0.12449364 -0.7323911 0.14250359 0.6634313 ]
[-0.45406207 -0.01028753 0.4023502 0.7505522 ]
[-0.15393102 -0.6168968 -0.33336315 -0.30407968]
[-0.45047694 0.08870795 -0.31647477 0.14311607]
[ 0.36785328 -0.54895645 -0.33745944 0.24393772]
[-0.19438711 0.10855066 -0.24903338 0.5316877 ]
[-0.26328734 -0.5420943 -0.46315232 -0.2344252 ]
[ 0.06137119 0.08665388 -0.35923252 0.4403924 ]]
"""

gravity = 9.8
Expand Down
24 changes: 14 additions & 10 deletions gymnasium/experimental/wrappers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,17 @@ class OrderEnforcingV0(gym.Wrapper):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Example:
>>> from gymnasium.envs.classic_control import CartPoleEnv
>>> env = CartPoleEnv()
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import OrderEnforcingV0
>>> env = gym.make("CartPole-v1", render_mode="human")
>>> env = OrderEnforcingV0(env)
>>> env.step(0)
ResetNeeded: Cannot call env.step() before calling env.reset()
>>> env.step(0) # doctest: +SKIP
gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset()
>>> env.render() # doctest: +SKIP
gymnasium.error.ResetNeeded('Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.')
>>> _ = env.reset()
>>> env.render()
ResetNeeded: Cannot call env.render() before calling env.reset()
>>> env.reset()
>>> env.render()
>>> env.step(0)
>>> _ = env.step(0)
"""

def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):
Expand Down Expand Up @@ -185,7 +186,6 @@ class RecordEpisodeStatisticsV0(gym.Wrapper):
After the completion of an episode, ``info`` will look like this::
>>> info = {
... ...
... "episode": {
... "r": "<cumulative reward>",
... "l": "<episode length>",
Expand All @@ -196,7 +196,10 @@ class RecordEpisodeStatisticsV0(gym.Wrapper):
For a vectorized environments the output will be in the form of::
>>> infos = {
... ...
... "final_observation": "<array of length num-envs>",
... "_final_observation": "<boolean array of length num-envs>",
... "final_info": "<array of length num-envs>",
... "_final_info": "<boolean array of length num-envs>",
... "episode": {
... "r": "<array of cumulative reward>",
... "l": "<array of episode length>",
Expand All @@ -205,6 +208,7 @@ class RecordEpisodeStatisticsV0(gym.Wrapper):
... "_episode": "<boolean array of length num-envs>"
... }
Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
:attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.
Expand Down
19 changes: 11 additions & 8 deletions gymnasium/experimental/wrappers/lambda_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ class ClipActionV0(LambdaActionV0):
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import ClipActionV0
>>> import numpy as np
>>> env = gym.make('BipedalWalker-v3', disable_env_checker=True)
>>> env = gym.make("Hopper-v4", disable_env_checker=True)
>>> env = ClipActionV0(env)
>>> env.action_space
Box(-1.0, 1.0, (4,), float32)
>>> env.step(np.array([5.0, 2.0, -10.0, 0.0]))
# Executes the action np.array([1.0, 1.0, -1.0, 0]) in the base environment
Box(-inf, inf, (3,), float32)
>>> _ = env.reset(seed=42)
>>> _ = env.step(np.array([5.0, -2.0, 0.0]))
... # Executes the action np.array([1.0, -1.0, 0]) in the base environment
"""

def __init__(self, env: gym.Env):
Expand Down Expand Up @@ -89,13 +91,14 @@ class RescaleActionV0(LambdaActionV0):
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import RescaleActionV0
>>> import numpy as np
>>> env = gym.make('BipedalWalker-v3', disable_env_checker=True)
>>> env = gym.make("Hopper-v4", disable_env_checker=True)
>>> _ = env.reset(seed=42)
>>> obs, _, _, _, _ = env.step(np.array([1,1,1,1]))
>>> obs, _, _, _, _ = env.step(np.array([1,1,1]))
>>> _ = env.reset(seed=42)
>>> min_action = -0.5
>>> max_action = np.array([0.0, 0.5, 1.0, 0.75])
>>> max_action = np.array([0.0, 0.5, 0.75])
>>> wrapped_env = RescaleActionV0(env, min_action=min_action, max_action=max_action)
>>> wrapped_env_obs, _, _, _, _ = wrapped_env.step(max_action)
>>> np.alltrue(obs == wrapped_env_obs)
Expand All @@ -122,7 +125,7 @@ def __init__(

if not isinstance(min_action, np.ndarray):
assert np.issubdtype(type(min_action), np.integer) or np.issubdtype(
type(max_action), np.floating
type(min_action), np.floating
)
min_action = np.full(env.action_space.shape, min_action)

Expand Down
44 changes: 27 additions & 17 deletions gymnasium/experimental/wrappers/lambda_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ class LambdaObservationV0(gym.ObservationWrapper):
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import LambdaObservationV0
>>> import numpy as np
>>> env = gym.make('CartPole-v1')
>>> env = LambdaObservationV0(env, lambda obs: obs + 0.1 * np.random.random(obs.shape))
>>> env.reset()
array([-0.08319338, 0.04635121, -0.07394746, 0.20877492])
>>> np.random.seed(0)
>>> env = gym.make("CartPole-v1")
>>> env = LambdaObservationV0(env, lambda obs: obs + 0.1 * np.random.random(obs.shape), env.observation_space)
>>> env.reset(seed=42) # doctest: +SKIP
(array([ 0.06199517, 0.0511615 , -0.04432538, 0.02694618]), {})
"""

def __init__(
Expand Down Expand Up @@ -75,17 +77,18 @@ class FilterObservationV0(LambdaObservationV0):
Example:
>>> import gymnasium as gym
>>> env = gym.wrappers.TransformObservation(
... gym.make('CartPole-v1'), lambda obs: {'obs': obs, 'time': 0}
... )
>>> from gymnasium.wrappers import TransformObservation
>>> from gymnasium.experimental.wrappers import FilterObservationV0
>>> env = gym.make("CartPole-v1")
>>> env = gym.wrappers.TransformObservation(env, lambda obs: {'obs': obs, 'time': 0})
>>> env.observation_space = gym.spaces.Dict(obs=env.observation_space, time=gym.spaces.Discrete(1))
>>> env.reset()
{'obs': array([-0.00067088, -0.01860439, 0.04772898, -0.01911527], dtype=float32), 'time': 0}
>>> env.reset(seed=42)
({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': 0}, {})
>>> env = FilterObservationV0(env, filter_keys=['time'])
>>> env.reset()
{'obs': array([ 0.04560107, 0.04466959, -0.0328232 , -0.02367178], dtype=float32)}
>>> env.reset(seed=42)
({'time': 0}, {})
>>> env.step(0)
({'obs': array([ 0.04649447, -0.14996664, -0.03329664, 0.25847703], dtype=float32)}, 1.0, False, {})
({'time': 0}, 1.0, False, False, {})
"""

def __init__(self, env: gym.Env, filter_keys: Sequence[str | int]):
Expand Down Expand Up @@ -171,13 +174,14 @@ class FlattenObservationV0(LambdaObservationV0):
Example:
>>> import gymnasium as gym
>>> env = gym.make('CarRacing-v1')
>>> from gymnasium.experimental.wrappers import FlattenObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
>>> env = FlattenObservationV0(env)
>>> env.observation_space.shape
(27648,)
>>> obs, info = env.reset()
>>> obs, _ = env.reset()
>>> obs.shape
(27648,)
"""
Expand All @@ -198,7 +202,8 @@ class GrayscaleObservationV0(LambdaObservationV0):
Example:
>>> import gymnasium as gym
>>> env = gym.make("CarRacing-v1")
>>> from gymnasium.experimental.wrappers import GrayscaleObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
>>> grayscale_env = GrayscaleObservationV0(env)
Expand Down Expand Up @@ -258,6 +263,7 @@ class ResizeObservationV0(LambdaObservationV0):
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import ResizeObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
Expand Down Expand Up @@ -303,7 +309,8 @@ class ReshapeObservationV0(LambdaObservationV0):
Example:
>>> import gymnasium as gym
>>> env = gym.make("CarRacing-v1")
>>> from gymnasium.experimental.wrappers import ReshapeObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
>>> reshape_env = ReshapeObservationV0(env, (24, 4, 96, 1, 3))
Expand Down Expand Up @@ -335,11 +342,14 @@ class RescaleObservationV0(LambdaObservationV0):
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import RescaleObservationV0
>>> env = gym.make("Pendulum-v1")
>>> env.observation_space
Box([-1. -1. -8.], [1. 1. 8.], (3,), float32)
>>> env = RescaleObservationV0(env, np.array([-2, -1, -10]), np.array([1, 0, 1]))
Box([-2. -1. -10.], [1. 0. 1.], (3,), float32)
>>> env.observation_space
Box([ -2. -1. -10.], [1. 0. 1.], (3,), float32)
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/experimental/wrappers/lambda_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class ClipRewardV0(LambdaRewardV0):
>>> from gymnasium.experimental.wrappers import ClipRewardV0
>>> env = gym.make("CartPole-v1")
>>> env = ClipRewardV0(env, 0, 0.5)
>>> env.reset()
>>> _ = env.reset()
>>> _, rew, _, _, _ = env.step(1)
>>> rew
0.5
Expand Down
14 changes: 8 additions & 6 deletions gymnasium/experimental/wrappers/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,26 +288,28 @@ class HumanRenderingV0(gym.Wrapper):
The ``render_mode`` of the wrapped environment must be either ``'rgb_array'`` or ``'rgb_array_list'``.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import HumanRenderingV0
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
>>> wrapped = HumanRenderingV0(env)
>>> wrapped.reset() # This will start rendering to the screen
>>> obs, _ = wrapped.reset() # This will start rendering to the screen
The wrapper can also be applied directly when the environment is instantiated, simply by passing
``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not
implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``).
Example:
>>> env = gym.make("NoNativeRendering-v2", render_mode="human") # NoNativeRendering-v0 doesn't implement human-rendering natively
>>> env.reset() # This will start rendering to the screen
>>> env = gym.make("CartPoleJax-v1", render_mode="human") # CartPoleJax-v1 doesn't implement human-rendering natively
>>> obs, _ = env.reset() # This will start rendering to the screen
Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method
will always return an empty list:
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list")
>>> wrapped = HumanRenderingV0(env)
>>> wrapped.reset()
>>> env.render()
[] # env.render() will always return an empty list!
>>> obs, _ = wrapped.reset()
>>> env.render() # env.render() will always return an empty list!
[]
"""

Expand Down
32 changes: 18 additions & 14 deletions gymnasium/experimental/wrappers/stateful_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,24 +80,27 @@ class TimeAwareObservationV0(gym.ObservationWrapper):
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import TimeAwareObservationV0
>>> env = gym.make('CartPole-v1')
>>> env = gym.make("CartPole-v1")
>>> env = TimeAwareObservationV0(env)
>>> env.observation_space
Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32))
>>> _ = env.reset()
Dict('obs': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), 'time': Box(0.0, 1.0, (1,), float32))
>>> _ = env.reset(seed=42)
>>> _ = env.action_space.seed(42)
>>> env.step(env.action_space.sample())[0]
OrderedDict([('obs',
... array([ 0.02866629, 0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)),
... ('time', array([0.002]))])
{'obs': array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476], dtype=float32), 'time': 0.002}
Flatten observation space example:
>>> env = gym.make('CartPole-v1')
>>> env = gym.make("CartPole-v1")
>>> env = TimeAwareObservationV0(env, flatten=True)
>>> env.observation_space
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32)
>>> _ = env.reset()
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38
0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 1.0000000e+00], (5,), float32)
>>> _ = env.reset(seed=42)
>>> _ = env.action_space.seed(42)
>>> env.step(env.action_space.sample())[0]
array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32)
array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476, 0.002 ],
dtype=float32)
"""

def __init__(
Expand Down Expand Up @@ -224,11 +227,12 @@ class FrameStackObservationV0(gym.Wrapper):
Example:
>>> import gymnasium as gym
>>> env = gym.make('CarRacing-v1')
>>> env = FrameStack(env, 4)
>>> from gymnasium.experimental.wrappers import FrameStackObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env = FrameStackObservationV0(env, 4)
>>> env.observation_space
Box(4, 96, 96, 3)
>>> obs = env.reset()
Box(0, 255, (4, 96, 96, 3), uint8)
>>> obs, _ = env.reset()
>>> obs.shape
(4, 96, 96, 3)
"""
Expand Down
4 changes: 2 additions & 2 deletions gymnasium/spaces/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ class Box(Space[NDArray[Any]]):
* Identical bound for each dimension::
>>> Box(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32)
Box(3, 4)
Box(-1.0, 2.0, (3, 4), float32)
* Independent bound for each dimension::
>>> Box(low=np.array([-1.0, -2.0]), high=np.array([2.0, 4.0]), dtype=np.float32)
Box(2,)
Box([-1. -2.], [2. 4.], (2,), float32)
"""

def __init__(
Expand Down
Loading

0 comments on commit b4caf9d

Please sign in to comment.