Skip to content

Commit

Permalink
Re-add functional api to experimental
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts committed Aug 15, 2024
1 parent 61592c0 commit 519f0e7
Show file tree
Hide file tree
Showing 10 changed files with 22 additions and 14 deletions.
4 changes: 2 additions & 2 deletions gymnasium/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
VectorizeMode,
register_envs,
)
from gymnasium import spaces, utils, vector, wrappers, error, logger, functional
from gymnasium import spaces, utils, vector, wrappers, error, logger, experimental

# Initializing pygame initializes audio connections through SDL. SDL uses alsa by default on all Linux systems
# SDL connecting to alsa frequently create these giant lists of warnings every time you import an environment using
Expand Down Expand Up @@ -64,7 +64,7 @@
"wrappers",
"error",
"logger",
"functional",
"experimental",
]
__version__ = "1.0.0"

Expand Down
2 changes: 1 addition & 1 deletion gymnasium/envs/functional_jax_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import gymnasium as gym
from gymnasium.envs.registration import EnvSpec
from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import seeding
from gymnasium.vector.utils import batch_space

Expand Down
2 changes: 1 addition & 1 deletion gymnasium/envs/phys2d/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import gymnasium as gym
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv
from gymnasium.error import DependencyNotInstalled
from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle


Expand Down
2 changes: 1 addition & 1 deletion gymnasium/envs/phys2d/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import gymnasium as gym
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv
from gymnasium.error import DependencyNotInstalled
from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle


Expand Down
2 changes: 1 addition & 1 deletion gymnasium/envs/tabular/blackjack.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from gymnasium import spaces
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv
from gymnasium.error import DependencyNotInstalled
from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle, seeding
from gymnasium.wrappers import HumanRendering

Expand Down
2 changes: 1 addition & 1 deletion gymnasium/envs/tabular/cliffwalking.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from gymnasium import spaces
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv
from gymnasium.error import DependencyNotInstalled
from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle
from gymnasium.wrappers import HumanRendering

Expand Down
1 change: 1 addition & 0 deletions gymnasium/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Experimental module."""
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/envs/functional/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from gymnasium.functional import FuncEnv
from gymnasium.experimental.functional import FuncEnv


class BasicTestEnv(FuncEnv):
Expand Down
19 changes: 13 additions & 6 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from gymnasium.functional import FuncEnv
from gymnasium.experimental.functional import FuncEnv


class GenericTestFuncEnv(FuncEnv):
Expand All @@ -16,25 +16,32 @@ def __init__(self, options: dict[str, Any] | None = None):
"""Constructor that allows generic options to be set on the environment."""
super().__init__(options)

def initial(self, rng: Any) -> np.ndarray:
def initial(self, rng: Any, params=None) -> np.ndarray:
"""Testing initial function."""
return np.array([0, 0], dtype=np.float32)

def observation(self, state: np.ndarray, rng: Any) -> np.ndarray:
def observation(self, state: np.ndarray, rng: Any, params=None) -> np.ndarray:
"""Testing observation function."""
return state

def transition(self, state: np.ndarray, action: int, rng: None) -> np.ndarray:
def transition(
self, state: np.ndarray, action: int, rng: None, params=None
) -> np.ndarray:
"""Testing transition function."""
return state + np.array([0, action], dtype=np.float32)

def reward(
self, state: np.ndarray, action: int, next_state: np.ndarray, rng: Any
self,
state: np.ndarray,
action: int,
next_state: np.ndarray,
rng: Any,
params=None,
) -> float:
"""Testing reward function."""
return 1.0 if next_state[1] > 0 else 0.0

def terminal(self, state: np.ndarray, rng: Any) -> bool:
def terminal(self, state: np.ndarray, rng: Any, params=None) -> bool:
"""Testing terminal function."""
return state[1] > 0

Expand Down

0 comments on commit 519f0e7

Please sign in to comment.