diff --git a/gymnasium/envs/phys2d/cartpole.py b/gymnasium/envs/phys2d/cartpole.py index 6599d5880..2c07f8f35 100644 --- a/gymnasium/envs/phys2d/cartpole.py +++ b/gymnasium/envs/phys2d/cartpole.py @@ -1,14 +1,13 @@ """Implementation of a Jax-accelerated cartpole environment.""" from __future__ import annotations -from dataclasses import dataclass from typing import Any, Tuple import jax import jax.numpy as jnp import numpy as np -from jax.random import PRNGKey from flax import struct +from jax.random import PRNGKey import gymnasium as gym from gymnasium.error import DependencyNotInstalled @@ -19,8 +18,11 @@ RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821 + @struct.dataclass class CartPoleParams: + """Parameters for the jax CartPole environment.""" + gravity: float = 9.8 masscart: float = 1.0 masspole: float = 0.1 @@ -104,14 +106,18 @@ class CartPoleFunctional( observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32) action_space = gym.spaces.Discrete(2) - def initial(self, rng: PRNGKey, params: CartPoleParams | None = CartPoleParams): + def initial(self, rng: PRNGKey, params: CartPoleParams = CartPoleParams): """Initial state generation.""" return jax.random.uniform( key=rng, minval=-params.x_init, maxval=params.x_init, shape=(4,) ) def transition( - self, state: jax.Array, action: int | jax.Array, rng: None = None, params: CartPoleParams | None = CartPoleParams + self, + state: jax.Array, + action: int | jax.Array, + rng: None = None, + params: CartPoleParams = CartPoleParams, ) -> StateType: """Cartpole transition.""" x, x_dot, theta, theta_dot = state @@ -125,7 +131,8 @@ def transition( force + params.polemass_length * theta_dot**2 * sintheta ) / params.total_mass thetaacc = (params.gravity * sintheta - costheta * temp) / ( - params.length * (4.0 / 3.0 - params.masspole * costheta**2 / params.total_mass) + params.length + * (4.0 / 3.0 - params.masspole * costheta**2 / params.total_mass) ) xacc = temp - params.polemass_length * thetaacc * costheta / params.total_mass @@ -138,11 +145,15 @@ def transition( return state - def observation(self, state: jax.Array, params: CartPoleParams | None = CartPoleParams) -> jax.Array: + def observation( + self, state: jax.Array, params: CartPoleParams = CartPoleParams + ) -> jax.Array: """Cartpole observation.""" return state - def terminal(self, state: jax.Array, params: CartPoleParams | None = CartPoleParams) -> jax.Array: + def terminal( + self, state: jax.Array, params: CartPoleParams = CartPoleParams + ) -> jax.Array: """Checks if the state is terminal.""" x, _, theta, _ = state @@ -156,7 +167,11 @@ def terminal(self, state: jax.Array, params: CartPoleParams | None = CartPolePar return terminated def reward( - self, state: StateType, action: ActType, next_state: StateType, params: CartPoleParams | None = CartPoleParams + self, + state: StateType, + action: ActType, + next_state: StateType, + params: CartPoleParams = CartPoleParams, ) -> jax.Array: """Computes the reward for the state transition using the action.""" x, _, theta, _ = state @@ -175,7 +190,7 @@ def render_image( self, state: StateType, render_state: RenderStateType, - params: CartPoleParams | None = CartPoleParams, + params: CartPoleParams = CartPoleParams, ) -> tuple[RenderStateType, np.ndarray]: """Renders an image of the state using the render state.""" try: diff --git a/gymnasium/envs/phys2d/pendulum.py b/gymnasium/envs/phys2d/pendulum.py index 30a3c64a6..478070f7d 100644 --- a/gymnasium/envs/phys2d/pendulum.py +++ b/gymnasium/envs/phys2d/pendulum.py @@ -7,6 +7,7 @@ import jax import jax.numpy as jnp import numpy as np +from flax import struct from jax.random import PRNGKey import gymnasium as gym @@ -19,57 +20,74 @@ RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821 +@struct.dataclass +class PendulumParams: + """Parameters for the jax Pendulum environment.""" + + max_speed: float = 8.0 + dt: float = 0.05 + g: float = 10.0 + m: float = 1.0 + l: float = 1.0 + high_x: float = jnp.pi + high_y: float = 1.0 + screen_dim: int = 500 + + class PendulumFunctional( - #FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType] + FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType, PendulumParams] ): """Pendulum but in jax and functional structure.""" - max_speed = 8 - max_torque = 2.0 - dt = 0.05 - g = 10.0 - m = 1.0 - l = 1.0 - high_x = jnp.pi - high_y = 1.0 - - screen_dim = 500 + max_torque: float = 2.0 observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32) action_space = gym.spaces.Box(-max_torque, max_torque, shape=(1,), dtype=np.float32) - def initial(self, rng: PRNGKey): + def initial(self, rng: PRNGKey, params: PendulumParams = PendulumParams): """Initial state generation.""" - high = jnp.array([self.high_x, self.high_y]) + high = jnp.array([params.high_x, params.high_y]) return jax.random.uniform(key=rng, minval=-high, maxval=high, shape=high.shape) def transition( - self, state: jax.Array, action: int | jax.Array, rng: None = None + self, + state: jax.Array, + action: int | jax.Array, + rng: None = None, + params: PendulumParams = PendulumParams, ) -> jax.Array: """Pendulum transition.""" th, thdot = state # th := theta u = action - g = self.g - m = self.m - l = self.l - dt = self.dt + g = params.g + m = params.m + l = params.l + dt = params.dt u = jnp.clip(u, -self.max_torque, self.max_torque)[0] newthdot = thdot + (3 * g / (2 * l) * jnp.sin(th) + 3.0 / (m * l**2) * u) * dt - newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed) + newthdot = jnp.clip(newthdot, -params.max_speed, params.max_speed) newth = th + newthdot * dt new_state = jnp.array([newth, newthdot]) return new_state - def observation(self, state: jax.Array) -> jax.Array: + def observation( + self, state: jax.Array, params: PendulumParams = PendulumParams + ) -> jax.Array: """Generates an observation based on the state.""" theta, thetadot = state return jnp.array([jnp.cos(theta), jnp.sin(theta), thetadot]) - def reward(self, state: StateType, action: ActType, next_state: StateType) -> float: + def reward( + self, + state: StateType, + action: ActType, + next_state: StateType, + params: PendulumParams = PendulumParams, + ) -> float: """Generates the reward based on the state, action and next state.""" th, thdot = state # th := theta u = action @@ -81,14 +99,17 @@ def reward(self, state: StateType, action: ActType, next_state: StateType) -> fl return -costs - def terminal(self, state: StateType) -> bool: + def terminal( + self, state: StateType, params: PendulumParams = PendulumParams + ) -> bool: """Determines if the state is a terminal state.""" return False def render_image( self, state: StateType, - render_state: tuple[pygame.Surface, pygame.time.Clock, float | None], # type: ignore # noqa: F821 + render_state: RenderStateType, + params: PendulumParams = PendulumParams, ) -> tuple[RenderStateType, np.ndarray]: """Renders an RGB image.""" try: @@ -100,12 +121,12 @@ def render_image( ) from e screen, clock, last_u = render_state - surf = pygame.Surface((self.screen_dim, self.screen_dim)) + surf = pygame.Surface((params.screen_dim, params.screen_dim)) surf.fill((255, 255, 255)) bound = 2.2 - scale = self.screen_dim / (bound * 2) - offset = self.screen_dim // 2 + scale = params.screen_dim / (bound * 2) + offset = params.screen_dim // 2 rod_length = 1 * scale rod_width = 0.2 * scale @@ -149,7 +170,6 @@ def render_image( ), ) - # drawing axle gfxdraw.aacircle(surf, offset, offset, int(0.05 * scale), (0, 0, 0)) gfxdraw.filled_circle(surf, offset, offset, int(0.05 * scale), (0, 0, 0)) @@ -161,7 +181,10 @@ def render_image( ) def render_initialise( - self, screen_width: int = 600, screen_height: int = 400 + self, + screen_width: int = 600, + screen_height: int = 400, + params: PendulumParams = PendulumParams, ) -> RenderStateType: """Initialises the render state.""" try: @@ -177,7 +200,11 @@ def render_initialise( return screen, clock, None - def render_close(self, render_state: RenderStateType): + def render_close( + self, + render_state: RenderStateType, + params: PendulumParams = PendulumParams, + ): """Closes the render state.""" try: import pygame diff --git a/gymnasium/functional.py b/gymnasium/functional.py index 77193c435..1d4e76672 100644 --- a/gymnasium/functional.py +++ b/gymnasium/functional.py @@ -18,7 +18,9 @@ class FuncEnv( - Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType, Params] + Generic[ + StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType, Params + ] ): """Base class (template) for functional envs. @@ -53,7 +55,9 @@ def initial(self, rng: Any, params: Params | None = None) -> StateType: """Generates the initial state of the environment with a random number generator.""" raise NotImplementedError - def transition(self, state: StateType, action: ActType, rng: Any, params: Params | None = None) -> StateType: + def transition( + self, state: StateType, action: ActType, rng: Any, params: Params | None = None + ) -> StateType: """Updates (transitions) the state with an action and random number generator.""" raise NotImplementedError @@ -62,7 +66,11 @@ def observation(self, state: StateType, params: Params | None = None) -> ObsType raise NotImplementedError def reward( - self, state: StateType, action: ActType, next_state: StateType, params: Params | None = None + self, + state: StateType, + action: ActType, + next_state: StateType, + params: Params | None = None, ) -> RewardType: """Computes the reward for a given transition between `state`, `action` to `next_state`.""" raise NotImplementedError @@ -76,7 +84,11 @@ def initial_info(self, state: StateType, params: Params | None = None) -> dict: return {} def transition_info( - self, state: StateType, action: ActType, next_state: StateType, params: Params | None = None + self, + state: StateType, + action: ActType, + next_state: StateType, + params: Params | None = None, ) -> dict: """Info dict about a full transition.""" return {} @@ -92,7 +104,10 @@ def transform(self, func: Callable[[Callable], Callable]): self.step_info = func(self.transition_info) def render_image( - self, state: StateType, render_state: RenderStateType, params: Params | None = None + self, + state: StateType, + render_state: RenderStateType, + params: Params | None = None, ) -> tuple[RenderStateType, np.ndarray]: """Show the state.""" raise NotImplementedError