From ce8f61a7b1e3cbc65f7ebfc4e7e31af2722a0ad1 Mon Sep 17 00:00:00 2001 From: ariel Date: Wed, 6 Dec 2023 12:06:36 -0500 Subject: [PATCH] Remove the code example from cartpole - the API is more general than cartpole now --- gymnasium/envs/phys2d/cartpole.py | 47 +------------------------------ 1 file changed, 1 insertion(+), 46 deletions(-) diff --git a/gymnasium/envs/phys2d/cartpole.py b/gymnasium/envs/phys2d/cartpole.py index 9f8cfbcc6..2842e857a 100644 --- a/gymnasium/envs/phys2d/cartpole.py +++ b/gymnasium/envs/phys2d/cartpole.py @@ -42,52 +42,7 @@ class CartPoleParams: class CartPoleFunctional( FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType, CartPoleParams] ): - """Cartpole but in jax and functional. - - Example: - >>> import jax - >>> import jax.numpy as jnp - >>> from gymnasium.envs.phys2d.cartpole import CartPoleFunctional, CartPoleParams - >>> key = jax.random.PRNGKey(0) - >>> params = CartPoleParams(x_init=0.5) - >>> env = CartPoleFunctional() - >>> state = env.initial(key, params=params) - >>> print(state) - [ 0.46532142 -0.27484107 0.13302994 -0.20361817] - >>> print(env.transition(state, 0, params=params)) # doctest: +SKI - [ 0.4598246 -0.6357784 0.12895757 0.1278053 ] - >>> env.transform(jax.jit) - >>> state = env.initial(key, params=params) - >>> print(state) # doctest: +SKI - [ 0.46532142 -0.27484107 0.13302994 -0.20361817] - >>> print(env.transition(state, 0, params=params)) # doctest: +SKIP - [ 0.4598246 -0.6357784 0.12895757 0.1278053 ] - >>> vkey = jax.random.split(key, 10) - >>> env.transform(jax.vmap) - >>> vstate = env.initial(vkey, params=params) - >>> print(vstate) - [[ 0.25117755 -0.03159595 0.09428263 0.12404168] - [ 0.231457 0.41420317 -0.13484478 0.29151905] - [-0.11706758 -0.37130308 0.13587534 0.33141208] - [-0.4613737 0.36557996 0.3950702 0.3639989 ] - [-0.14707637 -0.34273267 -0.32374108 -0.48110402] - [-0.45774353 0.3633288 -0.3157575 -0.03586268] - [ 0.37344885 -0.279778 -0.33894253 0.07415426] - [-0.20234215 0.39775252 -0.2556088 0.32877135] - [-0.2572986 -0.29943776 -0.45600426 -0.35740316] - [ 0.05436695 0.35021234 -0.36484408 0.2805779 ]] - >>> print(env.transition(vstate, jnp.array([0 for _ in range(10)]), params=params)) - [[ 0.25054562 -0.38763174 0.09676346 0.4448946 ] - [ 0.23974106 0.09849604 -0.1290144 0.5390002 ] - [-0.12449364 -0.7323911 0.14250359 0.6634313 ] - [-0.45406207 -0.01028753 0.4023502 0.7505522 ] - [-0.15393102 -0.6168968 -0.33336315 -0.30407968] - [-0.45047694 0.08870795 -0.31647477 0.14311607] - [ 0.36785328 -0.54895645 -0.33745944 0.24393772] - [-0.19438711 0.10855066 -0.24903338 0.5316877 ] - [-0.26328734 -0.5420943 -0.46315232 -0.2344252 ] - [ 0.06137119 0.08665388 -0.35923252 0.4403924 ]] - """ + """Cartpole but in jax and functional.""" observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32) action_space = gym.spaces.Discrete(2)