Skip to content

Commit

Permalink
Remove the code example from cartpole - the API is more general than …
Browse files Browse the repository at this point in the history
…cartpole now
  • Loading branch information
RedTachyon committed Dec 6, 2023
1 parent a9aea84 commit ce8f61a
Showing 1 changed file with 1 addition and 46 deletions.
47 changes: 1 addition & 46 deletions gymnasium/envs/phys2d/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,52 +42,7 @@ class CartPoleParams:
class CartPoleFunctional(
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType, CartPoleParams]
):
"""Cartpole but in jax and functional.
Example:
>>> import jax
>>> import jax.numpy as jnp
>>> from gymnasium.envs.phys2d.cartpole import CartPoleFunctional, CartPoleParams
>>> key = jax.random.PRNGKey(0)
>>> params = CartPoleParams(x_init=0.5)
>>> env = CartPoleFunctional()
>>> state = env.initial(key, params=params)
>>> print(state)
[ 0.46532142 -0.27484107 0.13302994 -0.20361817]
>>> print(env.transition(state, 0, params=params)) # doctest: +SKI
[ 0.4598246 -0.6357784 0.12895757 0.1278053 ]
>>> env.transform(jax.jit)
>>> state = env.initial(key, params=params)
>>> print(state) # doctest: +SKI
[ 0.46532142 -0.27484107 0.13302994 -0.20361817]
>>> print(env.transition(state, 0, params=params)) # doctest: +SKIP
[ 0.4598246 -0.6357784 0.12895757 0.1278053 ]
>>> vkey = jax.random.split(key, 10)
>>> env.transform(jax.vmap)
>>> vstate = env.initial(vkey, params=params)
>>> print(vstate)
[[ 0.25117755 -0.03159595 0.09428263 0.12404168]
[ 0.231457 0.41420317 -0.13484478 0.29151905]
[-0.11706758 -0.37130308 0.13587534 0.33141208]
[-0.4613737 0.36557996 0.3950702 0.3639989 ]
[-0.14707637 -0.34273267 -0.32374108 -0.48110402]
[-0.45774353 0.3633288 -0.3157575 -0.03586268]
[ 0.37344885 -0.279778 -0.33894253 0.07415426]
[-0.20234215 0.39775252 -0.2556088 0.32877135]
[-0.2572986 -0.29943776 -0.45600426 -0.35740316]
[ 0.05436695 0.35021234 -0.36484408 0.2805779 ]]
>>> print(env.transition(vstate, jnp.array([0 for _ in range(10)]), params=params))
[[ 0.25054562 -0.38763174 0.09676346 0.4448946 ]
[ 0.23974106 0.09849604 -0.1290144 0.5390002 ]
[-0.12449364 -0.7323911 0.14250359 0.6634313 ]
[-0.45406207 -0.01028753 0.4023502 0.7505522 ]
[-0.15393102 -0.6168968 -0.33336315 -0.30407968]
[-0.45047694 0.08870795 -0.31647477 0.14311607]
[ 0.36785328 -0.54895645 -0.33745944 0.24393772]
[-0.19438711 0.10855066 -0.24903338 0.5316877 ]
[-0.26328734 -0.5420943 -0.46315232 -0.2344252 ]
[ 0.06137119 0.08665388 -0.35923252 0.4403924 ]]
"""
"""Cartpole but in jax and functional."""

observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32)
action_space = gym.spaces.Discrete(2)
Expand Down

0 comments on commit ce8f61a

Please sign in to comment.