Skip to content

Commit

Permalink
WiP, changes to jax funcenv params handling
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon committed Dec 5, 2023
1 parent 6d492ea commit 4be15eb
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 41 deletions.
86 changes: 55 additions & 31 deletions gymnasium/envs/phys2d/cartpole.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Implementation of a Jax-accelerated cartpole environment."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey
from flax import struct

import gymnasium as gym
from gymnasium.error import DependencyNotInstalled
Expand All @@ -17,9 +19,26 @@

RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821

@struct.dataclass
class CartPoleParams:
gravity: float = 9.8
masscart: float = 1.0
masspole: float = 0.1
total_mass: float = masspole + masscart
length: float = 0.5
polemass_length: float = masspole + length
force_mag: float = 10.0
tau: float = 0.02
theta_threshold_radians: float = 12 * 2 * np.pi / 360
x_threshold: float = 2.4
x_init: float = 0.05

screen_width: int = 600
screen_height: int = 400


class CartPoleFunctional(
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType]
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType, CartPoleParams]
):
"""Cartpole but in jax and functional.
Expand Down Expand Up @@ -85,68 +104,68 @@ class CartPoleFunctional(
observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32)
action_space = gym.spaces.Discrete(2)

def initial(self, rng: PRNGKey):
def initial(self, rng: PRNGKey, params: CartPoleParams | None = CartPoleParams):
"""Initial state generation."""
return jax.random.uniform(
key=rng, minval=-self.x_init, maxval=self.x_init, shape=(4,)
key=rng, minval=-params.x_init, maxval=params.x_init, shape=(4,)
)

def transition(
self, state: jax.Array, action: int | jax.Array, rng: None = None
self, state: jax.Array, action: int | jax.Array, rng: None = None, params: CartPoleParams | None = CartPoleParams
) -> StateType:
"""Cartpole transition."""
x, x_dot, theta, theta_dot = state
force = jnp.sign(action - 0.5) * self.force_mag
force = jnp.sign(action - 0.5) * params.force_mag
costheta = jnp.cos(theta)
sintheta = jnp.sin(theta)

# For the interested reader:
# https://coneural.org/florian/papers/05_cart_pole.pdf
temp = (
force + self.polemass_length * theta_dot**2 * sintheta
) / self.total_mass
thetaacc = (self.gravity * sintheta - costheta * temp) / (
self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass)
force + params.polemass_length * theta_dot**2 * sintheta
) / params.total_mass
thetaacc = (params.gravity * sintheta - costheta * temp) / (
params.length * (4.0 / 3.0 - params.masspole * costheta**2 / params.total_mass)
)
xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass
xacc = temp - params.polemass_length * thetaacc * costheta / params.total_mass

x = x + self.tau * x_dot
x_dot = x_dot + self.tau * xacc
theta = theta + self.tau * theta_dot
theta_dot = theta_dot + self.tau * thetaacc
x = x + params.tau * x_dot
x_dot = x_dot + params.tau * xacc
theta = theta + params.tau * theta_dot
theta_dot = theta_dot + params.tau * thetaacc

state = jnp.array((x, x_dot, theta, theta_dot), dtype=jnp.float32)

return state

def observation(self, state: jax.Array) -> jax.Array:
def observation(self, state: jax.Array, params: CartPoleParams | None = CartPoleParams) -> jax.Array:
"""Cartpole observation."""
return state

def terminal(self, state: jax.Array) -> jax.Array:
def terminal(self, state: jax.Array, params: CartPoleParams | None = CartPoleParams) -> jax.Array:
"""Checks if the state is terminal."""
x, _, theta, _ = state

terminated = (
(x < -self.x_threshold)
| (x > self.x_threshold)
| (theta < -self.theta_threshold_radians)
| (theta > self.theta_threshold_radians)
(x < -params.x_threshold)
| (x > params.x_threshold)
| (theta < -params.theta_threshold_radians)
| (theta > params.theta_threshold_radians)
)

return terminated

def reward(
self, state: StateType, action: ActType, next_state: StateType
self, state: StateType, action: ActType, next_state: StateType, params: CartPoleParams | None = CartPoleParams
) -> jax.Array:
"""Computes the reward for the state transition using the action."""
x, _, theta, _ = state

terminated = (
(x < -self.x_threshold)
| (x > self.x_threshold)
| (theta < -self.theta_threshold_radians)
| (theta > self.theta_threshold_radians)
(x < -params.x_threshold)
| (x > params.x_threshold)
| (theta < -params.theta_threshold_radians)
| (theta > params.theta_threshold_radians)
)

reward = jax.lax.cond(terminated, lambda: 0.0, lambda: 1.0)
Expand All @@ -156,6 +175,7 @@ def render_image(
self,
state: StateType,
render_state: RenderStateType,
params: CartPoleParams | None = CartPoleParams,
) -> tuple[RenderStateType, np.ndarray]:
"""Renders an image of the state using the render state."""
try:
Expand All @@ -167,21 +187,21 @@ def render_image(
) from e
screen, clock = render_state

world_width = self.x_threshold * 2
scale = self.screen_width / world_width
world_width = params.x_threshold * 2
scale = params.screen_width / world_width
polewidth = 10.0
polelen = scale * (2 * self.length)
polelen = scale * (2 * params.length)
cartwidth = 50.0
cartheight = 30.0

x = state

surf = pygame.Surface((self.screen_width, self.screen_height))
surf = pygame.Surface((params.screen_width, params.screen_height))
surf.fill((255, 255, 255))

l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
axleoffset = cartheight / 4.0
cartx = x[0] * scale + self.screen_width / 2.0 # MIDDLE OF CART
cartx = x[0] * scale + params.screen_width / 2.0 # MIDDLE OF CART
carty = 100 # TOP OF CART
cart_coords = [(l, b), (l, t), (r, t), (r, b)]
cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
Expand Down Expand Up @@ -218,7 +238,7 @@ def render_image(
(129, 132, 203),
)

gfxdraw.hline(surf, 0, self.screen_width, carty, (0, 0, 0))
gfxdraw.hline(surf, 0, params.screen_width, carty, (0, 0, 0))

surf = pygame.transform.flip(surf, False, True)
screen.blit(surf, (0, 0))
Expand Down Expand Up @@ -255,6 +275,10 @@ def render_close(self, render_state: RenderStateType) -> None:
pygame.display.quit()
pygame.quit()

def get_default_params(self, **kwargs) -> CartPoleParams:
"""Returns the default parameters for the environment."""
return CartPoleParams(**kwargs)


class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based implementation of the CartPole environment."""
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/envs/phys2d/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


class PendulumFunctional(
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType]
#FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType]
):
"""Pendulum but in jax and functional structure."""

Expand Down
24 changes: 15 additions & 9 deletions gymnasium/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
RewardType = TypeVar("RewardType")
TerminalType = TypeVar("TerminalType")
RenderStateType = TypeVar("RenderStateType")
Params = TypeVar("Params")


class FuncEnv(
Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType]
Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType, Params]
):
"""Base class (template) for functional envs.
Expand Down Expand Up @@ -46,35 +47,36 @@ class FuncEnv(
def __init__(self, options: dict[str, Any] | None = None):
"""Initialize the environment constants."""
self.__dict__.update(options or {})
self.default_params = self.get_default_params()

def initial(self, rng: Any) -> StateType:
def initial(self, rng: Any, params: Params | None = None) -> StateType:
"""Generates the initial state of the environment with a random number generator."""
raise NotImplementedError

def transition(self, state: StateType, action: ActType, rng: Any) -> StateType:
def transition(self, state: StateType, action: ActType, rng: Any, params: Params | None = None) -> StateType:
"""Updates (transitions) the state with an action and random number generator."""
raise NotImplementedError

def observation(self, state: StateType) -> ObsType:
def observation(self, state: StateType, params: Params | None = None) -> ObsType:
"""Generates an observation for a given state of an environment."""
raise NotImplementedError

def reward(
self, state: StateType, action: ActType, next_state: StateType
self, state: StateType, action: ActType, next_state: StateType, params: Params | None = None
) -> RewardType:
"""Computes the reward for a given transition between `state`, `action` to `next_state`."""
raise NotImplementedError

def terminal(self, state: StateType) -> TerminalType:
def terminal(self, state: StateType, params: Params | None = None) -> TerminalType:
"""Returns if the state is a final terminal state."""
raise NotImplementedError

def initial_info(self, state: StateType) -> dict:
def initial_info(self, state: StateType, params: Params | None = None) -> dict:
"""Info dict about a single state."""
return {}

def transition_info(
self, state: StateType, action: ActType, next_state: StateType
self, state: StateType, action: ActType, next_state: StateType, params: Params | None = None
) -> dict:
"""Info dict about a full transition."""
return {}
Expand All @@ -90,7 +92,7 @@ def transform(self, func: Callable[[Callable], Callable]):
self.step_info = func(self.transition_info)

def render_image(
self, state: StateType, render_state: RenderStateType
self, state: StateType, render_state: RenderStateType, params: Params | None = None
) -> tuple[RenderStateType, np.ndarray]:
"""Show the state."""
raise NotImplementedError
Expand All @@ -102,3 +104,7 @@ def render_initialise(self, **kwargs) -> RenderStateType:
def render_close(self, render_state: RenderStateType):
"""Close the render state."""
raise NotImplementedError

def get_default_params(self, **kwargs) -> Params | None:
"""Get the default params."""
return None

0 comments on commit 4be15eb

Please sign in to comment.