Skip to content

Commit

Permalink
Remove gym from skrl.agents
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 26, 2024
1 parent e2ff56c commit 04c1f4d
Show file tree
Hide file tree
Showing 31 changed files with 132 additions and 163 deletions.
9 changes: 4 additions & 5 deletions skrl/agents/jax/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import copy
import functools
import gym
import gymnasium

import jax
Expand Down Expand Up @@ -172,8 +171,8 @@ class A2C(Agent):
def __init__(self,
models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
device: Optional[Union[str, jax.Device]] = None,
cfg: Optional[dict] = None) -> None:
"""Advantage Actor Critic (A2C)
Expand All @@ -187,9 +186,9 @@ def __init__(self,
for the rest only the environment transitions will be added
:type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
:param observation_space: Observation/state space or shape (default: ``None``)
:type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type observation_space: int, tuple or list of int, gymnasium.Space or None, optional
:param action_space: Action space or shape (default: ``None``)
:type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type action_space: int, tuple or list of int, gymnasium.Space or None, optional
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
If None, the device will be either ``"cuda"`` if available or ``"cpu"``
:type device: str or jax.Device, optional
Expand Down
9 changes: 4 additions & 5 deletions skrl/agents/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import datetime
import os
import pickle
import gym
import gymnasium

import flax
Expand All @@ -21,8 +20,8 @@ class Agent:
def __init__(self,
models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
device: Optional[Union[str, jax.Device]] = None,
cfg: Optional[dict] = None) -> None:
"""Base class that represent a RL agent
Expand All @@ -34,9 +33,9 @@ def __init__(self,
for the rest only the environment transitions will be added
:type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
:param observation_space: Observation/state space or shape (default: ``None``)
:type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type observation_space: int, tuple or list of int, gymnasium.Space or None, optional
:param action_space: Action space or shape (default: ``None``)
:type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type action_space: int, tuple or list of int, gymnasium.Space or None, optional
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
If None, the device will be either ``"cuda"`` if available or ``"cpu"``
:type device: str or jax.Device, optional
Expand Down
9 changes: 4 additions & 5 deletions skrl/agents/jax/cem/cem.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Mapping, Optional, Tuple, Union

import copy
import gym
import gymnasium

import jax
Expand Down Expand Up @@ -54,8 +53,8 @@ class CEM(Agent):
def __init__(self,
models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
device: Optional[Union[str, jax.Device]] = None,
cfg: Optional[dict] = None) -> None:
"""Cross-Entropy Method (CEM)
Expand All @@ -69,9 +68,9 @@ def __init__(self,
for the rest only the environment transitions will be added
:type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
:param observation_space: Observation/state space or shape (default: ``None``)
:type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type observation_space: int, tuple or list of int, gymnasium.Space or None, optional
:param action_space: Action space or shape (default: ``None``)
:type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type action_space: int, tuple or list of int, gymnasium.Space or None, optional
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
If None, the device will be either ``"cuda"`` if available or ``"cpu"``
:type device: str or jax.Device, optional
Expand Down
9 changes: 4 additions & 5 deletions skrl/agents/jax/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import copy
import functools
import gym
import gymnasium

import jax
Expand Down Expand Up @@ -114,8 +113,8 @@ class DDPG(Agent):
def __init__(self,
models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
device: Optional[Union[str, jax.Device]] = None,
cfg: Optional[dict] = None) -> None:
"""Deep Deterministic Policy Gradient (DDPG)
Expand All @@ -129,9 +128,9 @@ def __init__(self,
for the rest only the environment transitions will be added
:type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
:param observation_space: Observation/state space or shape (default: ``None``)
:type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type observation_space: int, tuple or list of int, gymnasium.Space or None, optional
:param action_space: Action space or shape (default: ``None``)
:type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type action_space: int, tuple or list of int, gymnasium.Space or None, optional
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
If None, the device will be either ``"cuda"`` if available or ``"cpu"``
:type device: str or jax.Device, optional
Expand Down
9 changes: 4 additions & 5 deletions skrl/agents/jax/dqn/ddqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import copy
import functools
import gym
import gymnasium

import jax
Expand Down Expand Up @@ -92,8 +91,8 @@ class DDQN(Agent):
def __init__(self,
models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
device: Optional[Union[str, jax.Device]] = None,
cfg: Optional[dict] = None) -> None:
"""Double Deep Q-Network (DDQN)
Expand All @@ -107,9 +106,9 @@ def __init__(self,
for the rest only the environment transitions will be added
:type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
:param observation_space: Observation/state space or shape (default: ``None``)
:type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type observation_space: int, tuple or list of int, gymnasium.Space or None, optional
:param action_space: Action space or shape (default: ``None``)
:type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type action_space: int, tuple or list of int, gymnasium.Space or None, optional
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
If None, the device will be either ``"cuda"`` if available or ``"cpu"``
:type device: str or jax.Device, optional
Expand Down
9 changes: 4 additions & 5 deletions skrl/agents/jax/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import copy
import functools
import gym
import gymnasium

import jax
Expand Down Expand Up @@ -89,8 +88,8 @@ class DQN(Agent):
def __init__(self,
models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
device: Optional[Union[str, jax.Device]] = None,
cfg: Optional[dict] = None) -> None:
"""Deep Q-Network (DQN)
Expand All @@ -104,9 +103,9 @@ def __init__(self,
for the rest only the environment transitions will be added
:type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
:param observation_space: Observation/state space or shape (default: ``None``)
:type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type observation_space: int, tuple or list of int, gymnasium.Space or None, optional
:param action_space: Action space or shape (default: ``None``)
:type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type action_space: int, tuple or list of int, gymnasium.Space or None, optional
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
If None, the device will be either ``"cuda"`` if available or ``"cpu"``
:type device: str or jax.Device, optional
Expand Down
9 changes: 4 additions & 5 deletions skrl/agents/jax/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import copy
import functools
import gym
import gymnasium

import jax
Expand Down Expand Up @@ -191,8 +190,8 @@ class PPO(Agent):
def __init__(self,
models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
device: Optional[Union[str, jax.Device]] = None,
cfg: Optional[dict] = None) -> None:
"""Proximal Policy Optimization (PPO)
Expand All @@ -206,9 +205,9 @@ def __init__(self,
for the rest only the environment transitions will be added
:type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
:param observation_space: Observation/state space or shape (default: ``None``)
:type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type observation_space: int, tuple or list of int, gymnasium.Space or None, optional
:param action_space: Action space or shape (default: ``None``)
:type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type action_space: int, tuple or list of int, gymnasium.Space or None, optional
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
If None, the device will be either ``"cuda"`` if available or ``"cpu"``
:type device: str or jax.Device, optional
Expand Down
9 changes: 4 additions & 5 deletions skrl/agents/jax/rpo/rpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import copy
import functools
import gym
import gymnasium

import jax
Expand Down Expand Up @@ -194,8 +193,8 @@ class RPO(Agent):
def __init__(self,
models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
device: Optional[Union[str, jax.Device]] = None,
cfg: Optional[dict] = None) -> None:
"""Robust Policy Optimization (RPO)
Expand All @@ -209,9 +208,9 @@ def __init__(self,
for the rest only the environment transitions will be added
:type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
:param observation_space: Observation/state space or shape (default: ``None``)
:type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type observation_space: int, tuple or list of int, gymnasium.Space or None, optional
:param action_space: Action space or shape (default: ``None``)
:type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type action_space: int, tuple or list of int, gymnasium.Space or None, optional
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
If None, the device will be either ``"cuda"`` if available or ``"cpu"``
:type device: str or jax.Device, optional
Expand Down
13 changes: 6 additions & 7 deletions skrl/agents/jax/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import copy
import functools
import gym
import gymnasium

import flax
Expand Down Expand Up @@ -125,8 +124,8 @@ class SAC(Agent):
def __init__(self,
models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
device: Optional[Union[str, jax.Device]] = None,
cfg: Optional[dict] = None) -> None:
"""Soft Actor-Critic (SAC)
Expand All @@ -140,9 +139,9 @@ def __init__(self,
for the rest only the environment transitions will be added
:type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
:param observation_space: Observation/state space or shape (default: ``None``)
:type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type observation_space: int, tuple or list of int, gymnasium.Space or None, optional
:param action_space: Action space or shape (default: ``None``)
:type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type action_space: int, tuple or list of int, gymnasium.Space or None, optional
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
If None, the device will be either ``"cuda"`` if available or ``"cpu"``
:type device: str or jax.Device, optional
Expand Down Expand Up @@ -213,9 +212,9 @@ def __init__(self,
if self._learn_entropy:
self._target_entropy = self.cfg["target_entropy"]
if self._target_entropy is None:
if issubclass(type(self.action_space), gym.spaces.Box) or issubclass(type(self.action_space), gymnasium.spaces.Box):
if issubclass(type(self.action_space), gymnasium.spaces.Box):
self._target_entropy = -np.prod(self.action_space.shape).astype(np.float32)
elif issubclass(type(self.action_space), gym.spaces.Discrete) or issubclass(type(self.action_space), gymnasium.spaces.Discrete):
elif issubclass(type(self.action_space), gymnasium.spaces.Discrete):
self._target_entropy = -self.action_space.n
else:
self._target_entropy = 0
Expand Down
9 changes: 4 additions & 5 deletions skrl/agents/jax/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import copy
import functools
import gym
import gymnasium

import jax
Expand Down Expand Up @@ -132,8 +131,8 @@ class TD3(Agent):
def __init__(self,
models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
device: Optional[Union[str, jax.Device]] = None,
cfg: Optional[dict] = None) -> None:
"""Twin Delayed DDPG (TD3)
Expand All @@ -147,9 +146,9 @@ def __init__(self,
for the rest only the environment transitions will be added
:type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
:param observation_space: Observation/state space or shape (default: ``None``)
:type observation_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type observation_space: int, tuple or list of int, gymnasium.Space or None, optional
:param action_space: Action space or shape (default: ``None``)
:type action_space: int, tuple or list of int, gym.Space, gymnasium.Space or None, optional
:type action_space: int, tuple or list of int, gymnasium.Space or None, optional
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
If None, the device will be either ``"cuda"`` if available or ``"cpu"``
:type device: str or jax.Device, optional
Expand Down
Loading

0 comments on commit 04c1f4d

Please sign in to comment.