Skip to content

Commit

Permalink
Try to fix jax pendulum
Browse files (browse the repository at this point in the history)
  • Loading branch information
RedTachyon committed Dec 6, 2023
1 parent 4be15eb commit a19ba2f
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 43 deletions.
33 changes: 24 additions & 9 deletions gymnasium/envs/phys2d/cartpole.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
"""Implementation of a Jax-accelerated cartpole environment."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey
from flax import struct
from jax.random import PRNGKey

import gymnasium as gym
from gymnasium.error import DependencyNotInstalled
Expand All @@ -19,8 +18,11 @@

RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821


@struct.dataclass
class CartPoleParams:
"""Parameters for the jax CartPole environment."""

gravity: float = 9.8
masscart: float = 1.0
masspole: float = 0.1
Expand Down Expand Up @@ -104,14 +106,18 @@ class CartPoleFunctional(
observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32)
action_space = gym.spaces.Discrete(2)

def initial(self, rng: PRNGKey, params: CartPoleParams | None = CartPoleParams):
def initial(self, rng: PRNGKey, params: CartPoleParams = CartPoleParams):
"""Initial state generation."""
return jax.random.uniform(
key=rng, minval=-params.x_init, maxval=params.x_init, shape=(4,)
)

def transition(
self, state: jax.Array, action: int | jax.Array, rng: None = None, params: CartPoleParams | None = CartPoleParams
self,
state: jax.Array,
action: int | jax.Array,
rng: None = None,
params: CartPoleParams = CartPoleParams,
) -> StateType:
"""Cartpole transition."""
x, x_dot, theta, theta_dot = state
Expand All @@ -125,7 +131,8 @@ def transition(
force + params.polemass_length * theta_dot**2 * sintheta
) / params.total_mass
thetaacc = (params.gravity * sintheta - costheta * temp) / (
params.length * (4.0 / 3.0 - params.masspole * costheta**2 / params.total_mass)
params.length
* (4.0 / 3.0 - params.masspole * costheta**2 / params.total_mass)
)
xacc = temp - params.polemass_length * thetaacc * costheta / params.total_mass

Expand All @@ -138,11 +145,15 @@ def transition(

return state

def observation(self, state: jax.Array, params: CartPoleParams | None = CartPoleParams) -> jax.Array:
def observation(
self, state: jax.Array, params: CartPoleParams = CartPoleParams
) -> jax.Array:
"""Cartpole observation."""
return state

def terminal(self, state: jax.Array, params: CartPoleParams | None = CartPoleParams) -> jax.Array:
def terminal(
self, state: jax.Array, params: CartPoleParams = CartPoleParams
) -> jax.Array:
"""Checks if the state is terminal."""
x, _, theta, _ = state

Expand All @@ -156,7 +167,11 @@ def terminal(self, state: jax.Array, params: CartPoleParams | None = CartPolePar
return terminated

def reward(
self, state: StateType, action: ActType, next_state: StateType, params: CartPoleParams | None = CartPoleParams
self,
state: StateType,
action: ActType,
next_state: StateType,
params: CartPoleParams = CartPoleParams,
) -> jax.Array:
"""Computes the reward for the state transition using the action."""
x, _, theta, _ = state
Expand All @@ -175,7 +190,7 @@ def render_image(
self,
state: StateType,
render_state: RenderStateType,
params: CartPoleParams | None = CartPoleParams,
params: CartPoleParams = CartPoleParams,
) -> tuple[RenderStateType, np.ndarray]:
"""Renders an image of the state using the render state."""
try:
Expand Down
85 changes: 56 additions & 29 deletions gymnasium/envs/phys2d/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
from jax.random import PRNGKey

import gymnasium as gym
Expand All @@ -19,57 +20,74 @@
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821


@struct.dataclass
class PendulumParams:
"""Parameters for the jax Pendulum environment.

Instances are accepted as the ``params`` argument of the
``PendulumFunctional`` methods (``initial``, ``transition``,
``render_image``, ...), which read the fields below.
"""

# Bound on the angular velocity: ``transition`` clips the updated velocity
# to the interval [-max_speed, max_speed].
max_speed: float = 8.0
# Step size that multiplies the velocity and angle increments in ``transition``.
dt: float = 0.05
# Coefficient in the ``3 * g / (2 * l) * sin(th)`` term of ``transition``
# (presumably gravitational acceleration -- confirm units with the env docs).
g: float = 10.0
# Used in the ``3.0 / (m * l**2) * u`` torque term of ``transition``
# (presumably the pendulum mass).
m: float = 1.0
# Appears in both dynamics terms of ``transition``
# (presumably the pendulum length).
l: float = 1.0
# ``initial`` samples the state uniformly from [-high, high] with
# high = [high_x, high_y]; the state is unpacked as (th, thdot) elsewhere,
# so high_x bounds the initial angle and high_y the initial angular velocity.
high_x: float = jnp.pi
high_y: float = 1.0
# Side length of the square pygame surface created in ``render_image``;
# the drawing scale and centre offset are derived from it.
screen_dim: int = 500


class PendulumFunctional(
#FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType]
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType, PendulumParams]
):
"""Pendulum but in jax and functional structure."""

max_speed = 8
max_torque = 2.0
dt = 0.05
g = 10.0
m = 1.0
l = 1.0
high_x = jnp.pi
high_y = 1.0

screen_dim = 500
max_torque: float = 2.0

observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32)
action_space = gym.spaces.Box(-max_torque, max_torque, shape=(1,), dtype=np.float32)

def initial(self, rng: PRNGKey):
def initial(self, rng: PRNGKey, params: PendulumParams = PendulumParams):
"""Initial state generation."""
high = jnp.array([self.high_x, self.high_y])
high = jnp.array([params.high_x, params.high_y])
return jax.random.uniform(key=rng, minval=-high, maxval=high, shape=high.shape)

def transition(
self, state: jax.Array, action: int | jax.Array, rng: None = None
self,
state: jax.Array,
action: int | jax.Array,
rng: None = None,
params: PendulumParams = PendulumParams,
) -> jax.Array:
"""Pendulum transition."""
th, thdot = state # th := theta
u = action

g = self.g
m = self.m
l = self.l
dt = self.dt
g = params.g
m = params.m
l = params.l
dt = params.dt

u = jnp.clip(u, -self.max_torque, self.max_torque)[0]

newthdot = thdot + (3 * g / (2 * l) * jnp.sin(th) + 3.0 / (m * l**2) * u) * dt
newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)
newthdot = jnp.clip(newthdot, -params.max_speed, params.max_speed)
newth = th + newthdot * dt

new_state = jnp.array([newth, newthdot])
return new_state

def observation(self, state: jax.Array) -> jax.Array:
def observation(
self, state: jax.Array, params: PendulumParams = PendulumParams
) -> jax.Array:
"""Generates an observation based on the state."""
theta, thetadot = state
return jnp.array([jnp.cos(theta), jnp.sin(theta), thetadot])

def reward(self, state: StateType, action: ActType, next_state: StateType) -> float:
def reward(
self,
state: StateType,
action: ActType,
next_state: StateType,
params: PendulumParams = PendulumParams,
) -> float:
"""Generates the reward based on the state, action and next state."""
th, thdot = state # th := theta
u = action
Expand All @@ -81,14 +99,17 @@ def reward(self, state: StateType, action: ActType, next_state: StateType) -> fl

return -costs

def terminal(self, state: StateType) -> bool:
def terminal(
self, state: StateType, params: PendulumParams = PendulumParams
) -> bool:
"""Determines if the state is a terminal state."""
return False

def render_image(
self,
state: StateType,
render_state: tuple[pygame.Surface, pygame.time.Clock, float | None], # type: ignore # noqa: F821
render_state: RenderStateType,
params: PendulumParams = PendulumParams,
) -> tuple[RenderStateType, np.ndarray]:
"""Renders an RGB image."""
try:
Expand All @@ -100,12 +121,12 @@ def render_image(
) from e
screen, clock, last_u = render_state

surf = pygame.Surface((self.screen_dim, self.screen_dim))
surf = pygame.Surface((params.screen_dim, params.screen_dim))
surf.fill((255, 255, 255))

bound = 2.2
scale = self.screen_dim / (bound * 2)
offset = self.screen_dim // 2
scale = params.screen_dim / (bound * 2)
offset = params.screen_dim // 2

rod_length = 1 * scale
rod_width = 0.2 * scale
Expand Down Expand Up @@ -149,7 +170,6 @@ def render_image(
),
)

# drawing axle
gfxdraw.aacircle(surf, offset, offset, int(0.05 * scale), (0, 0, 0))
gfxdraw.filled_circle(surf, offset, offset, int(0.05 * scale), (0, 0, 0))

Expand All @@ -161,7 +181,10 @@ def render_image(
)

def render_initialise(
self, screen_width: int = 600, screen_height: int = 400
self,
screen_width: int = 600,
screen_height: int = 400,
params: PendulumParams = PendulumParams,
) -> RenderStateType:
"""Initialises the render state."""
try:
Expand All @@ -177,7 +200,11 @@ def render_initialise(

return screen, clock, None

def render_close(self, render_state: RenderStateType):
def render_close(
self,
render_state: RenderStateType,
params: PendulumParams = PendulumParams,
):
"""Closes the render state."""
try:
import pygame
Expand Down
25 changes: 20 additions & 5 deletions gymnasium/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@


class FuncEnv(
Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType, Params]
Generic[
StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType, Params
]
):
"""Base class (template) for functional envs.
Expand Down Expand Up @@ -53,7 +55,9 @@ def initial(self, rng: Any, params: Params | None = None) -> StateType:
"""Generates the initial state of the environment with a random number generator."""
raise NotImplementedError

def transition(self, state: StateType, action: ActType, rng: Any, params: Params | None = None) -> StateType:
def transition(
self, state: StateType, action: ActType, rng: Any, params: Params | None = None
) -> StateType:
"""Updates (transitions) the state with an action and random number generator."""
raise NotImplementedError

Expand All @@ -62,7 +66,11 @@ def observation(self, state: StateType, params: Params | None = None) -> ObsType
raise NotImplementedError

def reward(
self, state: StateType, action: ActType, next_state: StateType, params: Params | None = None
self,
state: StateType,
action: ActType,
next_state: StateType,
params: Params | None = None,
) -> RewardType:
"""Computes the reward for a given transition between `state`, `action` to `next_state`."""
raise NotImplementedError
Expand All @@ -76,7 +84,11 @@ def initial_info(self, state: StateType, params: Params | None = None) -> dict:
return {}

def transition_info(
self, state: StateType, action: ActType, next_state: StateType, params: Params | None = None
self,
state: StateType,
action: ActType,
next_state: StateType,
params: Params | None = None,
) -> dict:
"""Info dict about a full transition."""
return {}
Expand All @@ -92,7 +104,10 @@ def transform(self, func: Callable[[Callable], Callable]):
self.step_info = func(self.transition_info)

def render_image(
self, state: StateType, render_state: RenderStateType, params: Params | None = None
self,
state: StateType,
render_state: RenderStateType,
params: Params | None = None,
) -> tuple[RenderStateType, np.ndarray]:
"""Show the state."""
raise NotImplementedError
Expand Down

0 comments on commit a19ba2f

Please sign in to comment.