Skip to content

Commit

Permalink
Remove gym from skrl.models
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 19, 2024
1 parent 0dc1d0b commit a2b322b
Show file tree
Hide file tree
Showing 12 changed files with 51 additions and 64 deletions.
29 changes: 14 additions & 15 deletions skrl/models/jax/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union

import gym
import gymnasium

import flax
Expand Down Expand Up @@ -35,13 +34,13 @@ def create(cls, *, apply_fn, params, **kwargs):


class Model(flax.linen.Module):
observation_space: Union[int, Sequence[int], gym.Space, gymnasium.Space]
action_space: Union[int, Sequence[int], gym.Space, gymnasium.Space]
observation_space: Union[int, Sequence[int], gymnasium.Space]
action_space: Union[int, Sequence[int], gymnasium.Space]
device: Optional[Union[str, jax.Device]] = None

def __init__(self,
observation_space: Union[int, Sequence[int], gym.Space, gymnasium.Space],
action_space: Union[int, Sequence[int], gym.Space, gymnasium.Space],
observation_space: Union[int, Sequence[int], gymnasium.Space],
action_space: Union[int, Sequence[int], gymnasium.Space],
device: Optional[Union[str, jax.Device]] = None,
parent: Optional[Any] = None,
name: Optional[str] = None) -> None:
Expand All @@ -50,17 +49,17 @@ def __init__(self,
The following properties are defined:
- ``device`` (jax.Device): Device to be used for the computations
- ``observation_space`` (int, sequence of int, gym.Space, gymnasium.Space): Observation/state space
- ``action_space`` (int, sequence of int, gym.Space, gymnasium.Space): Action space
- ``observation_space`` (int, sequence of int, gymnasium.Space): Observation/state space
- ``action_space`` (int, sequence of int, gymnasium.Space): Action space
- ``num_observations`` (int): Number of elements in the observation/state space
- ``num_actions`` (int): Number of elements in the action space
:param observation_space: Observation/state space or shape.
The ``num_observations`` property will contain the size of that space
:type observation_space: int, sequence of int, gym.Space, gymnasium.Space
:type observation_space: int, sequence of int, gymnasium.Space
:param action_space: Action space or shape.
The ``num_actions`` property will contain the size of that space
:type action_space: int, sequence of int, gym.Space, gymnasium.Space
:type action_space: int, sequence of int, gymnasium.Space
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
If None, the device will be either ``"cuda"`` if available or ``"cpu"``
:type device: str or jax.Device, optional
Expand Down Expand Up @@ -142,7 +141,7 @@ def init_state_dict(self,

def tensor_to_space(self,
tensor: Union[np.ndarray, jax.Array],
space: Union[gym.Space, gymnasium.Space],
space: gymnasium.Space,
start: int = 0) -> Union[Union[np.ndarray, jax.Array], dict]:
"""Map a flat tensor to a Gym/Gymnasium space
Expand All @@ -153,7 +152,7 @@ def tensor_to_space(self,
:param tensor: Tensor to map from
:type tensor: np.ndarray or jax.Array
:param space: Space to map the tensor to
:type space: gym.Space or gymnasium.Space
:type space: gymnasium.Space
:param start: Index of the first element of the tensor to map (default: ``0``)
:type start: int, optional
Expand All @@ -164,8 +163,8 @@ def tensor_to_space(self,
Example::
>>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)),
... 'b': gym.spaces.Discrete(4)})
>>> space = gymnasium.spaces.Dict({'a': gymnasium.spaces.Box(low=-1, high=1, shape=(2, 3)),
... 'b': gymnasium.spaces.Discrete(4)})
>>> tensor = jnp.array([[-0.3, -0.2, -0.1, 0.1, 0.2, 0.3, 2]])
>>>
>>> model.tensor_to_space(tensor, space)
Expand Down Expand Up @@ -198,10 +197,10 @@ def random_act(self,
:rtype: tuple of np.ndarray or jax.Array, None, and dict
"""
# discrete action space (Discrete)
if issubclass(type(self.action_space), gym.spaces.Discrete) or issubclass(type(self.action_space), gymnasium.spaces.Discrete):
if isinstance(self.action_space, gymnasium.spaces.Discrete):
actions = np.random.randint(self.action_space.n, size=(inputs["states"].shape[0], 1))
# continuous action space (Box)
elif issubclass(type(self.action_space), gym.spaces.Box) or issubclass(type(self.action_space), gymnasium.spaces.Box):
elif isinstance(self.action_space, gymnasium.spaces.Box):
actions = np.random.uniform(low=self.action_space.low[0], high=self.action_space.high[0], size=(inputs["states"].shape[0], self.num_actions))
else:
raise NotImplementedError(f"Action space type ({type(self.action_space)}) not supported")
Expand Down
4 changes: 2 additions & 2 deletions skrl/models/jax/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def __init__(self, unnormalized_log_prob: bool = True, role: str = "") -> None:
... x = nn.Dense(self.num_actions)(x)
... return x, {}
...
>>> # given an observation_space: gym.spaces.Box with shape (4,)
>>> # and an action_space: gym.spaces.Discrete with n = 2
>>> # given an observation_space: gymnasium.spaces.Box with shape (4,)
>>> # and an action_space: gymnasium.spaces.Discrete with n = 2
>>> model = Policy(observation_space, action_space)
>>>
>>> print(model)
Expand Down
8 changes: 3 additions & 5 deletions skrl/models/jax/deterministic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Mapping, Optional, Tuple, Union

import gym
import gymnasium

import flax
Expand Down Expand Up @@ -36,8 +35,8 @@ def __init__(self, clip_actions: bool = False, role: str = "") -> None:
... x = nn.Dense(1)(x)
... return x, {}
...
>>> # given an observation_space: gym.spaces.Box with shape (60,)
>>> # and an action_space: gym.spaces.Box with shape (8,)
>>> # given an observation_space: gymnasium.spaces.Box with shape (60,)
>>> # and an action_space: gymnasium.spaces.Box with shape (8,)
>>> model = Value(observation_space, action_space)
>>>
>>> print(model)
Expand All @@ -50,8 +49,7 @@ def __init__(self, clip_actions: bool = False, role: str = "") -> None:
"""
if not hasattr(self, "_d_clip_actions"):
self._d_clip_actions = {}
self._d_clip_actions[role] = clip_actions and (issubclass(type(self.action_space), gym.Space) or \
issubclass(type(self.action_space), gymnasium.Space))
self._d_clip_actions[role] = clip_actions and isinstance(self.action_space, gymnasium.Space)

if self._d_clip_actions[role]:
self.clip_actions_min = jnp.array(self.action_space.low, dtype=jnp.float32)
Expand Down
8 changes: 3 additions & 5 deletions skrl/models/jax/gaussian.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Mapping, Optional, Tuple, Union

from functools import partial
import gym
import gymnasium

import flax
Expand Down Expand Up @@ -102,8 +101,8 @@ def __init__(self,
... x = nn.elu(self.layer_2(x))
... return self.layer_3(x), self.log_std_parameter, {}
...
>>> # given an observation_space: gym.spaces.Box with shape (60,)
>>> # and an action_space: gym.spaces.Box with shape (8,)
>>> # given an observation_space: gymnasium.spaces.Box with shape (60,)
>>> # and an action_space: gymnasium.spaces.Box with shape (8,)
>>> model = Policy(observation_space, action_space)
>>>
>>> print(model)
Expand All @@ -114,8 +113,7 @@ def __init__(self,
device = StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)
)
"""
self._clip_actions = clip_actions and (issubclass(type(self.action_space), gym.Space) or \
issubclass(type(self.action_space), gymnasium.Space))
self._clip_actions = clip_actions and isinstance(self.action_space, gymnasium.Space)

if self._clip_actions:
self.clip_actions_min = jnp.array(self.action_space.low, dtype=jnp.float32)
Expand Down
4 changes: 2 additions & 2 deletions skrl/models/jax/multicategorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def __init__(self, unnormalized_log_prob: bool = True, reduction: str = "sum", r
... x = nn.Dense(self.num_actions)(x)
... return x, {}
...
>>> # given an observation_space: gym.spaces.Box with shape (4,)
>>> # and an action_space: gym.spaces.MultiDiscrete with nvec = [3, 2]
>>> # given an observation_space: gymnasium.spaces.Box with shape (4,)
>>> # and an action_space: gymnasium.spaces.MultiDiscrete with nvec = [3, 2]
>>> model = Policy(observation_space, action_space)
>>>
>>> print(model)
Expand Down
26 changes: 12 additions & 14 deletions skrl/models/torch/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

import collections
import gym
import gymnasium
from packaging import version

import numpy as np
import torch

from skrl import config, logger
Expand All @@ -14,25 +12,25 @@

class Model(torch.nn.Module):
def __init__(self,
observation_space: Union[int, Sequence[int], gym.Space, gymnasium.Space],
action_space: Union[int, Sequence[int], gym.Space, gymnasium.Space],
observation_space: Union[int, Sequence[int], gymnasium.Space],
action_space: Union[int, Sequence[int], gymnasium.Space],
device: Optional[Union[str, torch.device]] = None) -> None:
"""Base class representing a function approximator
The following properties are defined:
- ``device`` (torch.device): Device to be used for the computations
- ``observation_space`` (int, sequence of int, gym.Space, gymnasium.Space): Observation/state space
- ``action_space`` (int, sequence of int, gym.Space, gymnasium.Space): Action space
- ``observation_space`` (int, sequence of int, gymnasium.Space): Observation/state space
- ``action_space`` (int, sequence of int, gymnasium.Space): Action space
- ``num_observations`` (int): Number of elements in the observation/state space
- ``num_actions`` (int): Number of elements in the action space
:param observation_space: Observation/state space or shape.
The ``num_observations`` property will contain the size of that space
:type observation_space: int, sequence of int, gym.Space, gymnasium.Space
:type observation_space: int, sequence of int, gymnasium.Space
:param action_space: Action space or shape.
The ``num_actions`` property will contain the size of that space
:type action_space: int, sequence of int, gym.Space, gymnasium.Space
:type action_space: int, sequence of int, gymnasium.Space
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
If None, the device will be either ``"cuda"`` if available or ``"cpu"``
:type device: str or torch.device, optional
Expand Down Expand Up @@ -67,7 +65,7 @@ def act(self, inputs, role=""):

def tensor_to_space(self,
tensor: torch.Tensor,
space: Union[gym.Space, gymnasium.Space],
space: gymnasium.Space,
start: int = 0) -> Union[torch.Tensor, dict]:
"""Map a flat tensor to a Gym/Gymnasium space
Expand All @@ -78,7 +76,7 @@ def tensor_to_space(self,
:param tensor: Tensor to map from
:type tensor: torch.Tensor
:param space: Space to map the tensor to
:type space: gym.Space or gymnasium.Space
:type space: gymnasium.Space
:param start: Index of the first element of the tensor to map (default: ``0``)
:type start: int, optional
Expand All @@ -89,8 +87,8 @@ def tensor_to_space(self,
Example::
>>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)),
... 'b': gym.spaces.Discrete(4)})
>>> space = gymnasium.spaces.Dict({'a': gymnasium.spaces.Box(low=-1, high=1, shape=(2, 3)),
... 'b': gymnasium.spaces.Discrete(4)})
>>> tensor = torch.tensor([[-0.3, -0.2, -0.1, 0.1, 0.2, 0.3, 2]])
>>>
>>> model.tensor_to_space(tensor, space)
Expand Down Expand Up @@ -119,10 +117,10 @@ def random_act(self,
:rtype: tuple of torch.Tensor, None, and dict
"""
# discrete action space (Discrete)
if issubclass(type(self.action_space), gym.spaces.Discrete) or issubclass(type(self.action_space), gymnasium.spaces.Discrete):
if isinstance(self.action_space, gymnasium.spaces.Discrete):
return torch.randint(self.action_space.n, (inputs["states"].shape[0], 1), device=self.device), None, {}
# continuous action space (Box)
elif issubclass(type(self.action_space), gym.spaces.Box) or issubclass(type(self.action_space), gymnasium.spaces.Box):
elif isinstance(self.action_space, gymnasium.spaces.Box):
if self._random_distribution is None:
self._random_distribution = torch.distributions.uniform.Uniform(
low=torch.tensor(self.action_space.low[0], device=self.device, dtype=torch.float32),
Expand Down
4 changes: 2 additions & 2 deletions skrl/models/torch/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(self, unnormalized_log_prob: bool = True, role: str = "") -> None:
... def compute(self, inputs, role):
... return self.net(inputs["states"]), {}
...
>>> # given an observation_space: gym.spaces.Box with shape (4,)
>>> # and an action_space: gym.spaces.Discrete with n = 2
>>> # given an observation_space: gymnasium.spaces.Box with shape (4,)
>>> # and an action_space: gymnasium.spaces.Discrete with n = 2
>>> model = Policy(observation_space, action_space)
>>>
>>> print(model)
Expand Down
8 changes: 3 additions & 5 deletions skrl/models/torch/deterministic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Mapping, Tuple, Union

import gym
import gymnasium

import torch
Expand Down Expand Up @@ -36,8 +35,8 @@ def __init__(self, clip_actions: bool = False, role: str = "") -> None:
... def compute(self, inputs, role):
... return self.net(inputs["states"]), {}
...
>>> # given an observation_space: gym.spaces.Box with shape (60,)
>>> # and an action_space: gym.spaces.Box with shape (8,)
>>> # given an observation_space: gymnasium.spaces.Box with shape (60,)
>>> # and an action_space: gymnasium.spaces.Box with shape (8,)
>>> model = Value(observation_space, action_space)
>>>
>>> print(model)
Expand All @@ -51,8 +50,7 @@ def __init__(self, clip_actions: bool = False, role: str = "") -> None:
)
)
"""
self._clip_actions = clip_actions and (issubclass(type(self.action_space), gym.Space) or \
issubclass(type(self.action_space), gymnasium.Space))
self._clip_actions = clip_actions and isinstance(self.action_space, gymnasium.Space)

if self._clip_actions:
self._clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32)
Expand Down
8 changes: 3 additions & 5 deletions skrl/models/torch/gaussian.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Mapping, Tuple, Union

import gym
import gymnasium

import torch
Expand Down Expand Up @@ -57,8 +56,8 @@ def __init__(self,
... def compute(self, inputs, role):
... return self.net(inputs["states"]), self.log_std_parameter, {}
...
>>> # given an observation_space: gym.spaces.Box with shape (60,)
>>> # and an action_space: gym.spaces.Box with shape (8,)
>>> # given an observation_space: gymnasium.spaces.Box with shape (60,)
>>> # and an action_space: gymnasium.spaces.Box with shape (8,)
>>> model = Policy(observation_space, action_space)
>>>
>>> print(model)
Expand All @@ -72,8 +71,7 @@ def __init__(self,
)
)
"""
self._clip_actions = clip_actions and (issubclass(type(self.action_space), gym.Space) or \
issubclass(type(self.action_space), gymnasium.Space))
self._clip_actions = clip_actions and isinstance(self.action_space, gymnasium.Space)

if self._clip_actions:
self._clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32)
Expand Down
4 changes: 2 additions & 2 deletions skrl/models/torch/multicategorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def __init__(self, unnormalized_log_prob: bool = True, reduction: str = "sum", r
... def compute(self, inputs, role):
... return self.net(inputs["states"]), {}
...
>>> # given an observation_space: gym.spaces.Box with shape (4,)
>>> # and an action_space: gym.spaces.MultiDiscrete with nvec = [3, 2]
>>> # given an observation_space: gymnasium.spaces.Box with shape (4,)
>>> # and an action_space: gymnasium.spaces.MultiDiscrete with nvec = [3, 2]
>>> model = Policy(observation_space, action_space)
>>>
>>> print(model)
Expand Down
8 changes: 3 additions & 5 deletions skrl/models/torch/multivariate_gaussian.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Mapping, Tuple, Union

import gym
import gymnasium

import torch
Expand Down Expand Up @@ -50,8 +49,8 @@ def __init__(self,
... def compute(self, inputs, role):
... return self.net(inputs["states"]), self.log_std_parameter, {}
...
>>> # given an observation_space: gym.spaces.Box with shape (60,)
>>> # and an action_space: gym.spaces.Box with shape (8,)
>>> # given an observation_space: gymnasium.spaces.Box with shape (60,)
>>> # and an action_space: gymnasium.spaces.Box with shape (8,)
>>> model = Policy(observation_space, action_space)
>>>
>>> print(model)
Expand All @@ -65,8 +64,7 @@ def __init__(self,
)
)
"""
self._clip_actions = clip_actions and (issubclass(type(self.action_space), gym.Space) or \
issubclass(type(self.action_space), gymnasium.Space))
self._clip_actions = clip_actions and isinstance(self.action_space, gymnasium.Space)

if self._clip_actions:
self._clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32)
Expand Down
4 changes: 2 additions & 2 deletions skrl/models/torch/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def __init__(self, num_envs: int = 1, role: str = "") -> None:
... dim=-1, keepdim=True).view(-1,1)
... return actions, {}
...
>>> # given an observation_space: gym.spaces.Discrete with n=100
>>> # and an action_space: gym.spaces.Discrete with n=5
>>> # given an observation_space: gymnasium.spaces.Discrete with n=100
>>> # and an action_space: gymnasium.spaces.Discrete with n=5
>>> model = GreedyPolicy(observation_space, action_space, num_envs=1)
>>>
>>> print(model)
Expand Down

0 comments on commit a2b322b

Please sign in to comment.