From 4be15eb63c6b6dcdf11badfb37e09e9dde64fb0a Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Tue, 5 Dec 2023 02:19:45 +0100 Subject: [PATCH] WiP, changes to jax funcenv params handling --- gymnasium/envs/phys2d/cartpole.py | 86 ++++++++++++++++++++----------- gymnasium/envs/phys2d/pendulum.py | 2 +- gymnasium/functional.py | 24 +++++---- 3 files changed, 71 insertions(+), 41 deletions(-) diff --git a/gymnasium/envs/phys2d/cartpole.py b/gymnasium/envs/phys2d/cartpole.py index 6e55c5dff..6599d5880 100644 --- a/gymnasium/envs/phys2d/cartpole.py +++ b/gymnasium/envs/phys2d/cartpole.py @@ -1,12 +1,14 @@ """Implementation of a Jax-accelerated cartpole environment.""" from __future__ import annotations +from dataclasses import dataclass from typing import Any, Tuple import jax import jax.numpy as jnp import numpy as np from jax.random import PRNGKey +from flax import struct import gymnasium as gym from gymnasium.error import DependencyNotInstalled @@ -17,9 +19,26 @@ RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821 +@struct.dataclass +class CartPoleParams: + gravity: float = 9.8 + masscart: float = 1.0 + masspole: float = 0.1 + total_mass: float = masspole + masscart + length: float = 0.5 + polemass_length: float = masspole + length + force_mag: float = 10.0 + tau: float = 0.02 + theta_threshold_radians: float = 12 * 2 * np.pi / 360 + x_threshold: float = 2.4 + x_init: float = 0.05 + + screen_width: int = 600 + screen_height: int = 400 + class CartPoleFunctional( - FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType] + FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType, CartPoleParams] ): """Cartpole but in jax and functional. @@ -85,68 +104,68 @@ class CartPoleFunctional( observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32) action_space = gym.spaces.Discrete(2) - def initial(self, rng: PRNGKey): + def initial(self, rng: PRNGKey, params: CartPoleParams | None = CartPoleParams): """Initial state generation.""" return jax.random.uniform( - key=rng, minval=-self.x_init, maxval=self.x_init, shape=(4,) + key=rng, minval=-params.x_init, maxval=params.x_init, shape=(4,) ) def transition( - self, state: jax.Array, action: int | jax.Array, rng: None = None + self, state: jax.Array, action: int | jax.Array, rng: None = None, params: CartPoleParams | None = CartPoleParams ) -> StateType: """Cartpole transition.""" x, x_dot, theta, theta_dot = state - force = jnp.sign(action - 0.5) * self.force_mag + force = jnp.sign(action - 0.5) * params.force_mag costheta = jnp.cos(theta) sintheta = jnp.sin(theta) # For the interested reader: # https://coneural.org/florian/papers/05_cart_pole.pdf temp = ( - force + self.polemass_length * theta_dot**2 * sintheta - ) / self.total_mass - thetaacc = (self.gravity * sintheta - costheta * temp) / ( - self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass) + force + params.polemass_length * theta_dot**2 * sintheta + ) / params.total_mass + thetaacc = (params.gravity * sintheta - costheta * temp) / ( + params.length * (4.0 / 3.0 - params.masspole * costheta**2 / params.total_mass) ) - xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass + xacc = temp - params.polemass_length * thetaacc * costheta / params.total_mass - x = x + self.tau * x_dot - x_dot = x_dot + self.tau * xacc - theta = theta + self.tau * theta_dot - theta_dot = theta_dot + self.tau * thetaacc + x = x + params.tau * x_dot + x_dot = x_dot + params.tau * xacc + theta = theta + params.tau * theta_dot + theta_dot = theta_dot + params.tau * thetaacc state = jnp.array((x, x_dot, theta, theta_dot), dtype=jnp.float32) return state - def observation(self, state: jax.Array) -> jax.Array: + def observation(self, state: jax.Array, params: CartPoleParams | None = CartPoleParams) -> jax.Array: """Cartpole observation.""" return state - def terminal(self, state: jax.Array) -> jax.Array: + def terminal(self, state: jax.Array, params: CartPoleParams | None = CartPoleParams) -> jax.Array: """Checks if the state is terminal.""" x, _, theta, _ = state terminated = ( - (x < -self.x_threshold) - | (x > self.x_threshold) - | (theta < -self.theta_threshold_radians) - | (theta > self.theta_threshold_radians) + (x < -params.x_threshold) + | (x > params.x_threshold) + | (theta < -params.theta_threshold_radians) + | (theta > params.theta_threshold_radians) ) return terminated def reward( - self, state: StateType, action: ActType, next_state: StateType + self, state: StateType, action: ActType, next_state: StateType, params: CartPoleParams | None = CartPoleParams ) -> jax.Array: """Computes the reward for the state transition using the action.""" x, _, theta, _ = state terminated = ( - (x < -self.x_threshold) - | (x > self.x_threshold) - | (theta < -self.theta_threshold_radians) - | (theta > self.theta_threshold_radians) + (x < -params.x_threshold) + | (x > params.x_threshold) + | (theta < -params.theta_threshold_radians) + | (theta > params.theta_threshold_radians) ) reward = jax.lax.cond(terminated, lambda: 0.0, lambda: 1.0) @@ -156,6 +175,7 @@ def render_image( self, state: StateType, render_state: RenderStateType, + params: CartPoleParams | None = CartPoleParams, ) -> tuple[RenderStateType, np.ndarray]: """Renders an image of the state using the render state.""" try: @@ -167,21 +187,21 @@ def render_image( ) from e screen, clock = render_state - world_width = self.x_threshold * 2 - scale = self.screen_width / world_width + world_width = params.x_threshold * 2 + scale = params.screen_width / world_width polewidth = 10.0 - polelen = scale * (2 * self.length) + polelen = scale * (2 * params.length) cartwidth = 50.0 cartheight = 30.0 x = state - surf = pygame.Surface((self.screen_width, self.screen_height)) + surf = pygame.Surface((params.screen_width, params.screen_height)) surf.fill((255, 255, 255)) l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2 axleoffset = cartheight / 4.0 - cartx = x[0] * scale + self.screen_width / 2.0 # MIDDLE OF CART + cartx = x[0] * scale + params.screen_width / 2.0 # MIDDLE OF CART carty = 100 # TOP OF CART cart_coords = [(l, b), (l, t), (r, t), (r, b)] cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords] @@ -218,7 +238,7 @@ def render_image( (129, 132, 203), ) - gfxdraw.hline(surf, 0, self.screen_width, carty, (0, 0, 0)) + gfxdraw.hline(surf, 0, params.screen_width, carty, (0, 0, 0)) surf = pygame.transform.flip(surf, False, True) screen.blit(surf, (0, 0)) @@ -255,6 +275,10 @@ def render_close(self, render_state: RenderStateType) -> None: pygame.display.quit() pygame.quit() + def get_default_params(self, **kwargs) -> CartPoleParams: + """Returns the default parameters for the environment.""" + return CartPoleParams(**kwargs) + class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle): """Jax-based implementation of the CartPole environment.""" diff --git a/gymnasium/envs/phys2d/pendulum.py b/gymnasium/envs/phys2d/pendulum.py index 29333912d..30a3c64a6 100644 --- a/gymnasium/envs/phys2d/pendulum.py +++ b/gymnasium/envs/phys2d/pendulum.py @@ -20,7 +20,7 @@ class PendulumFunctional( - FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType] + #FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType] ): """Pendulum but in jax and functional structure.""" diff --git a/gymnasium/functional.py b/gymnasium/functional.py index e8816ca61..77193c435 100644 --- a/gymnasium/functional.py +++ b/gymnasium/functional.py @@ -14,10 +14,11 @@ RewardType = TypeVar("RewardType") TerminalType = TypeVar("TerminalType") RenderStateType = TypeVar("RenderStateType") +Params = TypeVar("Params") class FuncEnv( - Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType] + Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType, Params] ): """Base class (template) for functional envs. @@ -46,35 +47,36 @@ class FuncEnv( def __init__(self, options: dict[str, Any] | None = None): """Initialize the environment constants.""" self.__dict__.update(options or {}) + self.default_params = self.get_default_params() - def initial(self, rng: Any) -> StateType: + def initial(self, rng: Any, params: Params | None = None) -> StateType: """Generates the initial state of the environment with a random number generator.""" raise NotImplementedError - def transition(self, state: StateType, action: ActType, rng: Any) -> StateType: + def transition(self, state: StateType, action: ActType, rng: Any, params: Params | None = None) -> StateType: """Updates (transitions) the state with an action and random number generator.""" raise NotImplementedError - def observation(self, state: StateType) -> ObsType: + def observation(self, state: StateType, params: Params | None = None) -> ObsType: """Generates an observation for a given state of an environment.""" raise NotImplementedError def reward( - self, state: StateType, action: ActType, next_state: StateType + self, state: StateType, action: ActType, next_state: StateType, params: Params | None = None ) -> RewardType: """Computes the reward for a given transition between `state`, `action` to `next_state`.""" raise NotImplementedError - def terminal(self, state: StateType) -> TerminalType: + def terminal(self, state: StateType, params: Params | None = None) -> TerminalType: """Returns if the state is a final terminal state.""" raise NotImplementedError - def initial_info(self, state: StateType) -> dict: + def initial_info(self, state: StateType, params: Params | None = None) -> dict: """Info dict about a single state.""" return {} def transition_info( - self, state: StateType, action: ActType, next_state: StateType + self, state: StateType, action: ActType, next_state: StateType, params: Params | None = None ) -> dict: """Info dict about a full transition.""" return {} @@ -90,7 +92,7 @@ def transform(self, func: Callable[[Callable], Callable]): self.step_info = func(self.transition_info) def render_image( - self, state: StateType, render_state: RenderStateType + self, state: StateType, render_state: RenderStateType, params: Params | None = None ) -> tuple[RenderStateType, np.ndarray]: """Show the state.""" raise NotImplementedError @@ -102,3 +104,7 @@ def render_initialise(self, **kwargs) -> RenderStateType: def render_close(self, render_state: RenderStateType): """Close the render state.""" raise NotImplementedError + + def get_default_params(self, **kwargs) -> Params | None: + """Get the default params.""" + return None