Skip to content

Commit

Permalink
Change the functional API to include explicit params (#818)
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon authored Dec 17, 2023
1 parent e1e3c45 commit e7e80a9
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 160 deletions.
1 change: 0 additions & 1 deletion gymnasium/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@
register(
id="tabular/Blackjack-v0",
entry_point="gymnasium.envs.tabular.blackjack:BlackJackJaxEnv",
kwargs={"sutton_and_barto": True, "natural": False},
)

register(
Expand Down
162 changes: 71 additions & 91 deletions gymnasium/envs/phys2d/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
from jax.random import PRNGKey

import gymnasium as gym
Expand All @@ -18,135 +19,109 @@
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821


@struct.dataclass
class CartPoleParams:
"""Parameters for the jax CartPole environment."""

gravity: float = 9.8
masscart: float = 1.0
masspole: float = 0.1
total_mass: float = masspole + masscart
length: float = 0.5
polemass_length: float = masspole + length
force_mag: float = 10.0
tau: float = 0.02
theta_threshold_radians: float = 12 * 2 * np.pi / 360
x_threshold: float = 2.4
x_init: float = 0.05

screen_width: int = 600
screen_height: int = 400


class CartPoleFunctional(
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType]
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType, CartPoleParams]
):
"""Cartpole but in jax and functional.
Example:
>>> import jax
>>> import jax.numpy as jnp
>>> from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
>>> key = jax.random.PRNGKey(0)
>>> env = CartPoleFunctional({"x_init": 0.5})
>>> state = env.initial(key)
>>> print(state)
[ 0.46532142 -0.27484107 0.13302994 -0.20361817]
>>> print(env.transition(state, 0))
[ 0.4598246 -0.6357784 0.12895757 0.1278053 ]
>>> env.transform(jax.jit)
>>> state = env.initial(key)
>>> print(state)
[ 0.46532142 -0.27484107 0.13302994 -0.20361817]
>>> print(env.transition(state, 0))
[ 0.4598246 -0.6357784 0.12895757 0.12780523]
>>> vkey = jax.random.split(key, 10)
>>> env.transform(jax.vmap)
>>> vstate = env.initial(vkey)
>>> print(vstate)
[[ 0.25117755 -0.03159595 0.09428263 0.12404168]
[ 0.231457 0.41420317 -0.13484478 0.29151905]
[-0.11706758 -0.37130308 0.13587534 0.33141208]
[-0.4613737 0.36557996 0.3950702 0.3639989 ]
[-0.14707637 -0.34273267 -0.32374108 -0.48110402]
[-0.45774353 0.3633288 -0.3157575 -0.03586268]
[ 0.37344885 -0.279778 -0.33894253 0.07415426]
[-0.20234215 0.39775252 -0.2556088 0.32877135]
[-0.2572986 -0.29943776 -0.45600426 -0.35740316]
[ 0.05436695 0.35021234 -0.36484408 0.2805779 ]]
>>> print(env.transition(vstate, jnp.array([0 for _ in range(10)])))
[[ 0.25054562 -0.38763174 0.09676346 0.4448946 ]
[ 0.23974106 0.09849604 -0.1290144 0.5390002 ]
[-0.12449364 -0.7323911 0.14250359 0.6634313 ]
[-0.45406207 -0.01028753 0.4023502 0.7505522 ]
[-0.15393102 -0.6168968 -0.33336315 -0.30407968]
[-0.45047694 0.08870795 -0.31647477 0.14311607]
[ 0.36785328 -0.54895645 -0.33745944 0.24393772]
[-0.19438711 0.10855066 -0.24903338 0.5316877 ]
[-0.26328734 -0.5420943 -0.46315232 -0.2344252 ]
[ 0.06137119 0.08665388 -0.35923252 0.4403924 ]]
"""

gravity = 9.8
masscart = 1.0
masspole = 0.1
total_mass = masspole + masscart
length = 0.5
polemass_length = masspole + length
force_mag = 10.0
tau = 0.02
theta_threshold_radians = 12 * 2 * np.pi / 360
x_threshold = 2.4
x_init = 0.05

screen_width = 600
screen_height = 400
"""Cartpole but in jax and functional."""

observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32)
action_space = gym.spaces.Discrete(2)

def initial(self, rng: PRNGKey):
def initial(self, rng: PRNGKey, params: CartPoleParams = CartPoleParams):
"""Initial state generation."""
return jax.random.uniform(
key=rng, minval=-self.x_init, maxval=self.x_init, shape=(4,)
key=rng, minval=-params.x_init, maxval=params.x_init, shape=(4,)
)

def transition(
self, state: jax.Array, action: int | jax.Array, rng: None = None
self,
state: jax.Array,
action: int | jax.Array,
rng: None = None,
params: CartPoleParams = CartPoleParams,
) -> StateType:
"""Cartpole transition."""
x, x_dot, theta, theta_dot = state
force = jnp.sign(action - 0.5) * self.force_mag
force = jnp.sign(action - 0.5) * params.force_mag
costheta = jnp.cos(theta)
sintheta = jnp.sin(theta)

# For the interested reader:
# https://coneural.org/florian/papers/05_cart_pole.pdf
temp = (
force + self.polemass_length * theta_dot**2 * sintheta
) / self.total_mass
thetaacc = (self.gravity * sintheta - costheta * temp) / (
self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass)
force + params.polemass_length * theta_dot**2 * sintheta
) / params.total_mass
thetaacc = (params.gravity * sintheta - costheta * temp) / (
params.length
* (4.0 / 3.0 - params.masspole * costheta**2 / params.total_mass)
)
xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass
xacc = temp - params.polemass_length * thetaacc * costheta / params.total_mass

x = x + self.tau * x_dot
x_dot = x_dot + self.tau * xacc
theta = theta + self.tau * theta_dot
theta_dot = theta_dot + self.tau * thetaacc
x = x + params.tau * x_dot
x_dot = x_dot + params.tau * xacc
theta = theta + params.tau * theta_dot
theta_dot = theta_dot + params.tau * thetaacc

state = jnp.array((x, x_dot, theta, theta_dot), dtype=jnp.float32)

return state

def observation(self, state: jax.Array) -> jax.Array:
def observation(
self, state: jax.Array, params: CartPoleParams = CartPoleParams
) -> jax.Array:
"""Cartpole observation."""
return state

def terminal(self, state: jax.Array) -> jax.Array:
def terminal(
self, state: jax.Array, params: CartPoleParams = CartPoleParams
) -> jax.Array:
"""Checks if the state is terminal."""
x, _, theta, _ = state

terminated = (
(x < -self.x_threshold)
| (x > self.x_threshold)
| (theta < -self.theta_threshold_radians)
| (theta > self.theta_threshold_radians)
(x < -params.x_threshold)
| (x > params.x_threshold)
| (theta < -params.theta_threshold_radians)
| (theta > params.theta_threshold_radians)
)

return terminated

def reward(
self, state: StateType, action: ActType, next_state: StateType
self,
state: StateType,
action: ActType,
next_state: StateType,
params: CartPoleParams = CartPoleParams,
) -> jax.Array:
"""Computes the reward for the state transition using the action."""
x, _, theta, _ = state

terminated = (
(x < -self.x_threshold)
| (x > self.x_threshold)
| (theta < -self.theta_threshold_radians)
| (theta > self.theta_threshold_radians)
(x < -params.x_threshold)
| (x > params.x_threshold)
| (theta < -params.theta_threshold_radians)
| (theta > params.theta_threshold_radians)
)

reward = jax.lax.cond(terminated, lambda: 0.0, lambda: 1.0)
Expand All @@ -156,6 +131,7 @@ def render_image(
self,
state: StateType,
render_state: RenderStateType,
params: CartPoleParams = CartPoleParams,
) -> tuple[RenderStateType, np.ndarray]:
"""Renders an image of the state using the render state."""
try:
Expand All @@ -167,21 +143,21 @@ def render_image(
) from e
screen, clock = render_state

world_width = self.x_threshold * 2
scale = self.screen_width / world_width
world_width = params.x_threshold * 2
scale = params.screen_width / world_width
polewidth = 10.0
polelen = scale * (2 * self.length)
polelen = scale * (2 * params.length)
cartwidth = 50.0
cartheight = 30.0

x = state

surf = pygame.Surface((self.screen_width, self.screen_height))
surf = pygame.Surface((params.screen_width, params.screen_height))
surf.fill((255, 255, 255))

l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
axleoffset = cartheight / 4.0
cartx = x[0] * scale + self.screen_width / 2.0 # MIDDLE OF CART
cartx = x[0] * scale + params.screen_width / 2.0 # MIDDLE OF CART
carty = 100 # TOP OF CART
cart_coords = [(l, b), (l, t), (r, t), (r, b)]
cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
Expand Down Expand Up @@ -218,7 +194,7 @@ def render_image(
(129, 132, 203),
)

gfxdraw.hline(surf, 0, self.screen_width, carty, (0, 0, 0))
gfxdraw.hline(surf, 0, params.screen_width, carty, (0, 0, 0))

surf = pygame.transform.flip(surf, False, True)
screen.blit(surf, (0, 0))
Expand Down Expand Up @@ -255,6 +231,10 @@ def render_close(self, render_state: RenderStateType) -> None:
pygame.display.quit()
pygame.quit()

def get_default_params(self, **kwargs) -> CartPoleParams:
"""Returns the default parameters for the environment."""
return CartPoleParams(**kwargs)


class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based implementation of the CartPole environment."""
Expand Down
Loading

0 comments on commit e7e80a9

Please sign in to comment.