Skip to content

Commit

Permalink
Replace each _get_space_size method call by the compute_space_size ut…
Browse files Browse the repository at this point in the history
…ility in jax
  • Loading branch information
Toni-SM committed Oct 7, 2024
1 parent 17c6f11 commit e003ada
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 202 deletions.
48 changes: 3 additions & 45 deletions skrl/memories/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import numpy as np

from skrl import config
from skrl.utils.spaces.jax import compute_space_size


# https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function
Expand Down Expand Up @@ -107,49 +108,6 @@ def __len__(self) -> int:
"""
return self.memory_size * self.num_envs if self.filled else self.memory_index * self.num_envs + self.env_index

def _get_space_size(self,
space: Union[int, Tuple[int], gym.Space, gymnasium.Space],
keep_dimensions: bool = False) -> Union[Tuple, int]:
"""Get the size (number of elements) of a space
:param space: Space or shape from which to obtain the number of elements
:type space: int, tuple or list of integers, gym.Space, or gymnasium.Space
:param keep_dimensions: Whether or not to keep the space dimensions (default: ``False``)
:type keep_dimensions: bool, optional
:raises ValueError: If the space is not supported
:return: Size of the space. If ``keep_dimensions`` is True, the space size will be a tuple
:rtype: int or tuple of int
"""
if type(space) in [int, float]:
return (int(space),) if keep_dimensions else int(space)
elif type(space) in [tuple, list]:
return tuple(space) if keep_dimensions else np.prod(space)
elif issubclass(type(space), gym.Space):
if issubclass(type(space), gym.spaces.Discrete):
return (1,) if keep_dimensions else 1
elif issubclass(type(space), gym.spaces.MultiDiscrete):
return space.nvec.shape[0]
elif issubclass(type(space), gym.spaces.Box):
return tuple(space.shape) if keep_dimensions else np.prod(space.shape)
elif issubclass(type(space), gym.spaces.Dict):
if keep_dimensions:
raise ValueError("keep_dimensions=True cannot be used with Dict spaces")
return sum([self._get_space_size(space.spaces[key]) for key in space.spaces])
elif issubclass(type(space), gymnasium.Space):
if issubclass(type(space), gymnasium.spaces.Discrete):
return (1,) if keep_dimensions else 1
elif issubclass(type(space), gymnasium.spaces.MultiDiscrete):
return space.nvec.shape[0]
elif issubclass(type(space), gymnasium.spaces.Box):
return tuple(space.shape) if keep_dimensions else np.prod(space.shape)
elif issubclass(type(space), gymnasium.spaces.Dict):
if keep_dimensions:
raise ValueError("keep_dimensions=True cannot be used with Dict spaces")
return sum([self._get_space_size(space.spaces[key]) for key in space.spaces])
raise ValueError(f"Space type {type(space)} not supported")

def _get_tensors_view(self, name):
if self.tensors_keep_dimensions[name]:
return self.tensors_view[name] if self._views else self.tensors[name].reshape(-1, *self.tensors_keep_dimensions[name])
Expand Down Expand Up @@ -204,7 +162,7 @@ def create_tensor(self,
name: str,
size: Union[int, Tuple[int], gym.Space, gymnasium.Space],
dtype: Optional[np.dtype] = None,
keep_dimensions: bool = True) -> bool:
keep_dimensions: bool = False) -> bool:
"""Create a new internal tensor in memory
The tensor will have a 3-components shape (memory size, number of environments, size).
Expand All @@ -227,7 +185,7 @@ def create_tensor(self,
:rtype: bool
"""
# compute data size
size = self._get_space_size(size, keep_dimensions)
size = compute_space_size(size, occupied_size=True)
# check dtype and size if the tensor exists
if name in self.tensors:
tensor = self.tensors[name]
Expand Down
131 changes: 6 additions & 125 deletions skrl/models/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np

from skrl import config
from skrl.utils.spaces.torch import compute_space_size, unflatten_tensorized_space


@jax.jit
Expand Down Expand Up @@ -100,8 +101,8 @@ def __call__(self, inputs, role):

self.observation_space = observation_space
self.action_space = action_space
self.num_observations = None if observation_space is None else self._get_space_size(observation_space)
self.num_actions = None if action_space is None else self._get_space_size(action_space)
self.num_observations = None if observation_space is None else compute_space_size(observation_space)
self.num_actions = None if action_space is None else compute_space_size(action_space)

self.state_dict: StateDict
self.training = False
Expand Down Expand Up @@ -139,111 +140,15 @@ def init_state_dict(self,
with jax.default_device(self.device):
self.state_dict = StateDict.create(apply_fn=self.apply, params=self.init(key, inputs, role))

def _get_space_size(self,
space: Union[int, Sequence[int], gym.Space, gymnasium.Space],
number_of_elements: bool = True) -> int:
"""Get the size (number of elements) of a space
:param space: Space or shape from which to obtain the number of elements
:type space: int, sequence of int, gym.Space, or gymnasium.Space
:param number_of_elements: Whether the number of elements occupied by the space is returned (default: ``True``).
If ``False``, the shape of the space is returned.
It only affects Discrete and MultiDiscrete spaces
:type number_of_elements: bool, optional
:raises ValueError: If the space is not supported
:return: Size of the space (number of elements)
:rtype: int
Example::
# from int
>>> model._get_space_size(2)
2
# from sequence of int
>>> model._get_space_size([2, 3])
6
# Box space
>>> space = gym.spaces.Box(low=-1, high=1, shape=(2, 3))
>>> model._get_space_size(space)
6
# Discrete space
>>> space = gym.spaces.Discrete(4)
>>> model._get_space_size(space)
4
>>> model._get_space_size(space, number_of_elements=False)
1
# MultiDiscrete space
>>> space = gym.spaces.MultiDiscrete([5, 3, 2])
>>> model._get_space_size(space)
10
>>> model._get_space_size(space, number_of_elements=False)
3
# Dict space
>>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)),
... 'b': gym.spaces.Discrete(4)})
>>> model._get_space_size(space)
10
>>> model._get_space_size(space, number_of_elements=False)
7
"""
size = None
if type(space) in [int, float]:
size = space
elif type(space) in [tuple, list]:
size = np.prod(space)
elif issubclass(type(space), gym.Space):
if issubclass(type(space), gym.spaces.Discrete):
if number_of_elements:
size = space.n
else:
size = 1
elif issubclass(type(space), gym.spaces.MultiDiscrete):
if number_of_elements:
size = np.sum(space.nvec)
else:
size = space.nvec.shape[0]
elif issubclass(type(space), gym.spaces.Box):
size = np.prod(space.shape)
elif issubclass(type(space), gym.spaces.Dict):
size = sum([self._get_space_size(space.spaces[key], number_of_elements) for key in space.spaces])
elif issubclass(type(space), gymnasium.Space):
if issubclass(type(space), gymnasium.spaces.Discrete):
if number_of_elements:
size = space.n
else:
size = 1
elif issubclass(type(space), gymnasium.spaces.MultiDiscrete):
if number_of_elements:
size = np.sum(space.nvec)
else:
size = space.nvec.shape[0]
elif issubclass(type(space), gymnasium.spaces.Box):
size = np.prod(space.shape)
elif issubclass(type(space), gymnasium.spaces.Dict):
size = sum([self._get_space_size(space.spaces[key], number_of_elements) for key in space.spaces])
if size is None:
raise ValueError(f"Space type {type(space)} not supported")
return int(size)

def tensor_to_space(self,
tensor: Union[np.ndarray, jax.Array],
space: Union[gym.Space, gymnasium.Space],
start: int = 0) -> Union[Union[np.ndarray, jax.Array], dict]:
"""Map a flat tensor to a Gym/Gymnasium space
The mapping is done in the following way:
.. warning::
- Tensors belonging to Discrete spaces are returned without modification
- Tensors belonging to Box spaces are reshaped to the corresponding space shape
keeping the first dimension (number of samples) as they are
- Tensors belonging to Dict spaces are mapped into a dictionary with the same keys as the original space
This method is deprecated in favor of the :py:func:`skrl.utils.spaces.jax.unflatten_tensorized_space`
:param tensor: Tensor to map from
:type tensor: np.ndarray or jax.Array
Expand All @@ -268,31 +173,7 @@ def tensor_to_space(self,
[ 0.1, 0.2, 0.3]]], dtype=float32),
'b': Array([[2.]], dtype=float32)}
"""
if issubclass(type(space), gym.Space):
if issubclass(type(space), gym.spaces.Discrete):
return tensor
elif issubclass(type(space), gym.spaces.Box):
return tensor.reshape(tensor.shape[0], *space.shape)
elif issubclass(type(space), gym.spaces.Dict):
output = {}
for k in sorted(space.keys()):
end = start + self._get_space_size(space[k], number_of_elements=False)
output[k] = self.tensor_to_space(tensor[:, start:end], space[k], end)
start = end
return output
else:
if issubclass(type(space), gymnasium.spaces.Discrete):
return tensor
elif issubclass(type(space), gymnasium.spaces.Box):
return tensor.reshape(tensor.shape[0], *space.shape)
elif issubclass(type(space), gymnasium.spaces.Dict):
output = {}
for k in sorted(space.keys()):
end = start + self._get_space_size(space[k], number_of_elements=False)
output[k] = self.tensor_to_space(tensor[:, start:end], space[k], end)
start = end
return output
raise ValueError(f"Space type {type(space)} not supported")
return unflatten_tensorized_space(space, tensor)

def random_act(self,
inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]],
Expand Down
34 changes: 2 additions & 32 deletions skrl/resources/preprocessors/jax/running_standard_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np

from skrl import config
from skrl.utils.spaces.jax import compute_space_size


# https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function
Expand Down Expand Up @@ -100,7 +101,7 @@ def __init__(self,
device_type, device_index = f"{device}:0".split(':')[:2]
self.device = jax.devices(device_type)[int(device_index)]

size = self._get_space_size(size)
size = compute_space_size(size, occupied_size=True)

if self._jax:
with jax.default_device(self.device):
Expand Down Expand Up @@ -140,37 +141,6 @@ def state_dict(self, value: Mapping[str, Union[np.ndarray, jax.Array]]) -> None:
np.copyto(self.running_variance, value["running_variance"])
np.copyto(self.current_count, value["current_count"])

def _get_space_size(self, space: Union[int, Tuple[int], gym.Space, gymnasium.Space]) -> int:
"""Get the size (number of elements) of a space
:param space: Space or shape from which to obtain the number of elements
:type space: int, tuple or list of integers, gym.Space, or gymnasium.Space
:raises ValueError: If the space is not supported
:return: Size of the space data
:rtype: Space size (number of elements)
"""
if type(space) in [int, float]:
return int(space)
elif type(space) in [tuple, list]:
return np.prod(space)
elif issubclass(type(space), gym.Space):
if issubclass(type(space), gym.spaces.Discrete):
return 1
elif issubclass(type(space), gym.spaces.Box):
return np.prod(space.shape)
elif issubclass(type(space), gym.spaces.Dict):
return sum([self._get_space_size(space.spaces[key]) for key in space.spaces])
elif issubclass(type(space), gymnasium.Space):
if issubclass(type(space), gymnasium.spaces.Discrete):
return 1
elif issubclass(type(space), gymnasium.spaces.Box):
return np.prod(space.shape)
elif issubclass(type(space), gymnasium.spaces.Dict):
return sum([self._get_space_size(space.spaces[key]) for key in space.spaces])
raise ValueError(f"Space type {type(space)} not supported")

def _parallel_variance(self,
input_mean: Union[np.ndarray, jax.Array],
input_var: Union[np.ndarray, jax.Array],
Expand Down

0 comments on commit e003ada

Please sign in to comment.