From 519f0e710c4b9d49573aaf780ae453c4cec32363 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Thu, 15 Aug 2024 14:30:06 +0100 Subject: [PATCH] Re-add functional api to experimental --- gymnasium/__init__.py | 4 ++-- gymnasium/envs/functional_jax_env.py | 2 +- gymnasium/envs/phys2d/cartpole.py | 2 +- gymnasium/envs/phys2d/pendulum.py | 2 +- gymnasium/envs/tabular/blackjack.py | 2 +- gymnasium/envs/tabular/cliffwalking.py | 2 +- gymnasium/experimental/__init__.py | 1 + gymnasium/{ => experimental}/functional.py | 0 tests/envs/functional/test_core.py | 2 +- tests/functional/test_functional.py | 19 +++++++++++++------ 10 files changed, 22 insertions(+), 14 deletions(-) create mode 100644 gymnasium/experimental/__init__.py rename gymnasium/{ => experimental}/functional.py (100%) diff --git a/gymnasium/__init__.py b/gymnasium/__init__.py index 1ea51b05f..1364a6b4a 100644 --- a/gymnasium/__init__.py +++ b/gymnasium/__init__.py @@ -20,7 +20,7 @@ VectorizeMode, register_envs, ) -from gymnasium import spaces, utils, vector, wrappers, error, logger, functional +from gymnasium import spaces, utils, vector, wrappers, error, logger, experimental # Initializing pygame initializes audio connections through SDL. SDL uses alsa by default on all Linux systems # SDL connecting to alsa frequently create these giant lists of warnings every time you import an environment using @@ -64,7 +64,7 @@ "wrappers", "error", "logger", - "functional", + "experimental", ] __version__ = "1.0.0" diff --git a/gymnasium/envs/functional_jax_env.py b/gymnasium/envs/functional_jax_env.py index 88741ad41..1c3c32576 100644 --- a/gymnasium/envs/functional_jax_env.py +++ b/gymnasium/envs/functional_jax_env.py @@ -10,7 +10,7 @@ import gymnasium as gym from gymnasium.envs.registration import EnvSpec -from gymnasium.functional import ActType, FuncEnv, StateType +from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.utils import seeding from gymnasium.vector.utils import batch_space diff --git a/gymnasium/envs/phys2d/cartpole.py b/gymnasium/envs/phys2d/cartpole.py index 5f4ee3fe3..99c3b5f44 100644 --- a/gymnasium/envs/phys2d/cartpole.py +++ b/gymnasium/envs/phys2d/cartpole.py @@ -13,7 +13,7 @@ import gymnasium as gym from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv from gymnasium.error import DependencyNotInstalled -from gymnasium.functional import ActType, FuncEnv, StateType +from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.utils import EzPickle diff --git a/gymnasium/envs/phys2d/pendulum.py b/gymnasium/envs/phys2d/pendulum.py index 7cc16293a..2e2538263 100644 --- a/gymnasium/envs/phys2d/pendulum.py +++ b/gymnasium/envs/phys2d/pendulum.py @@ -14,7 +14,7 @@ import gymnasium as gym from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv from gymnasium.error import DependencyNotInstalled -from gymnasium.functional import ActType, FuncEnv, StateType +from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.utils import EzPickle diff --git a/gymnasium/envs/tabular/blackjack.py b/gymnasium/envs/tabular/blackjack.py index 181891169..714bfea07 100644 --- a/gymnasium/envs/tabular/blackjack.py +++ b/gymnasium/envs/tabular/blackjack.py @@ -14,7 +14,7 @@ from gymnasium import spaces from gymnasium.envs.functional_jax_env import FunctionalJaxEnv from gymnasium.error import DependencyNotInstalled -from gymnasium.functional import ActType, FuncEnv, StateType +from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.utils import EzPickle, seeding from gymnasium.wrappers import HumanRendering diff --git a/gymnasium/envs/tabular/cliffwalking.py b/gymnasium/envs/tabular/cliffwalking.py index 21e0b1247..511dabe21 100644 --- a/gymnasium/envs/tabular/cliffwalking.py +++ b/gymnasium/envs/tabular/cliffwalking.py @@ -13,7 +13,7 @@ from gymnasium import spaces from gymnasium.envs.functional_jax_env import FunctionalJaxEnv from gymnasium.error import DependencyNotInstalled -from gymnasium.functional import ActType, FuncEnv, StateType +from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.utils import EzPickle from gymnasium.wrappers import HumanRendering diff --git a/gymnasium/experimental/__init__.py b/gymnasium/experimental/__init__.py new file mode 100644 index 000000000..91119bfec --- /dev/null +++ b/gymnasium/experimental/__init__.py @@ -0,0 +1 @@ +"""Experimental module.""" diff --git a/gymnasium/functional.py b/gymnasium/experimental/functional.py similarity index 100% rename from gymnasium/functional.py rename to gymnasium/experimental/functional.py diff --git a/tests/envs/functional/test_core.py b/tests/envs/functional/test_core.py index 412c624ac..3b7349e04 100644 --- a/tests/envs/functional/test_core.py +++ b/tests/envs/functional/test_core.py @@ -2,7 +2,7 @@ import numpy as np -from gymnasium.functional import FuncEnv +from gymnasium.experimental.functional import FuncEnv class BasicTestEnv(FuncEnv): diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index d4cb2f8f8..6fb2c4e6d 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -6,7 +6,7 @@ import numpy as np -from gymnasium.functional import FuncEnv +from gymnasium.experimental.functional import FuncEnv class GenericTestFuncEnv(FuncEnv): @@ -16,25 +16,32 @@ def __init__(self, options: dict[str, Any] | None = None): """Constructor that allows generic options to be set on the environment.""" super().__init__(options) - def initial(self, rng: Any) -> np.ndarray: + def initial(self, rng: Any, params=None) -> np.ndarray: """Testing initial function.""" return np.array([0, 0], dtype=np.float32) - def observation(self, state: np.ndarray, rng: Any) -> np.ndarray: + def observation(self, state: np.ndarray, rng: Any, params=None) -> np.ndarray: """Testing observation function.""" return state - def transition(self, state: np.ndarray, action: int, rng: None) -> np.ndarray: + def transition( + self, state: np.ndarray, action: int, rng: None, params=None + ) -> np.ndarray: """Testing transition function.""" return state + np.array([0, action], dtype=np.float32) def reward( - self, state: np.ndarray, action: int, next_state: np.ndarray, rng: Any + self, + state: np.ndarray, + action: int, + next_state: np.ndarray, + rng: Any, + params=None, ) -> float: """Testing reward function.""" return 1.0 if next_state[1] > 0 else 0.0 - def terminal(self, state: np.ndarray, rng: Any) -> bool: + def terminal(self, state: np.ndarray, rng: Any, params=None) -> bool: """Testing terminal function.""" return state[1] > 0