From 3b6184dc9adb9fe91de4292f5871130f4b72addf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Mon, 4 Nov 2024 12:11:18 -0500 Subject: [PATCH 1/8] Configure pre-commit hooks --- .pre-commit-config.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5cc902e4..7523f9f4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,6 +14,7 @@ repos: - id: end-of-file-fixer - id: name-tests-test args: ["--pytest-test-first"] + exclude: ^(tests/stategies.py|tests/utils.py) - id: no-commit-to-branch - id: trailing-whitespace - repo: https://github.com/codespell-project/codespell @@ -26,6 +27,8 @@ repos: rev: 24.8.0 hooks: - id: black + args: ["--line-length=120"] + exclude: ^(docs/|tests/) - repo: https://github.com/pycqa/isort rev: 5.13.2 hooks: From d334f22f99585fff9d8f8bbdfb73af665e57b5cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Mon, 4 Nov 2024 12:17:21 -0500 Subject: [PATCH 2/8] Configure black in pre-commit and pyproject.toml --- .pre-commit-config.yaml | 2 +- pyproject.toml | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7523f9f4..058aec63 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,7 +28,7 @@ repos: hooks: - id: black args: ["--line-length=120"] - exclude: ^(docs/|tests/) + exclude: ^(docs/) - repo: https://github.com/pycqa/isort rev: 5.13.2 hooks: diff --git a/pyproject.toml b/pyproject.toml index d3d04cb4..5e0da0f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,11 @@ tests = [ [tool.black] line-length = 120 +extend-exclude = """ +( + ^/docs +) +""" [tool.codespell] From 6ef70fe24fe2bb64847f995cd21964cf2bf9fa44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Mon, 4 Nov 2024 12:35:39 -0500 Subject: [PATCH 3/8] Configure codespell in pre-commit --- .pre-commit-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 058aec63..bd2fa15d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,6 +21,7 @@ repos: rev: v2.3.0 hooks: - id: codespell + exclude: ^(docs/source/_static|docs/_build|pyproject.toml) additional_dependencies: - tomli - repo: https://github.com/python/black From 6942e7c3c7c92e597b016880312a1cea3f338ca7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Mon, 4 Nov 2024 16:34:44 -0500 Subject: [PATCH 4/8] Ignore formating config sections --- skrl/agents/jax/a2c/a2c.py | 2 ++ skrl/agents/jax/cem/cem.py | 2 ++ skrl/agents/jax/ddpg/ddpg.py | 2 ++ skrl/agents/jax/dqn/ddqn.py | 2 ++ skrl/agents/jax/dqn/dqn.py | 2 ++ skrl/agents/jax/ppo/ppo.py | 2 ++ skrl/agents/jax/rpo/rpo.py | 2 ++ skrl/agents/jax/sac/sac.py | 2 ++ skrl/agents/jax/td3/td3.py | 2 ++ skrl/agents/torch/a2c/a2c.py | 2 ++ skrl/agents/torch/a2c/a2c_rnn.py | 2 ++ skrl/agents/torch/amp/amp.py | 2 ++ skrl/agents/torch/cem/cem.py | 2 ++ skrl/agents/torch/ddpg/ddpg.py | 2 ++ skrl/agents/torch/ddpg/ddpg_rnn.py | 2 ++ skrl/agents/torch/dqn/ddqn.py | 2 ++ skrl/agents/torch/dqn/dqn.py | 2 ++ skrl/agents/torch/ppo/ppo.py | 2 ++ skrl/agents/torch/ppo/ppo_rnn.py | 2 ++ skrl/agents/torch/q_learning/q_learning.py | 2 ++ skrl/agents/torch/rpo/rpo.py | 2 ++ skrl/agents/torch/rpo/rpo_rnn.py | 2 ++ skrl/agents/torch/sac/sac.py | 2 ++ skrl/agents/torch/sac/sac_rnn.py | 2 ++ skrl/agents/torch/sarsa/sarsa.py | 2 ++ 
skrl/agents/torch/td3/td3.py | 2 ++ skrl/agents/torch/td3/td3_rnn.py | 2 ++ skrl/agents/torch/trpo/trpo.py | 2 ++ skrl/agents/torch/trpo/trpo_rnn.py | 2 ++ skrl/multi_agents/jax/ippo/ippo.py | 2 ++ skrl/multi_agents/jax/mappo/mappo.py | 2 ++ skrl/multi_agents/torch/ippo/ippo.py | 2 ++ skrl/multi_agents/torch/mappo/mappo.py | 2 ++ skrl/trainers/jax/sequential.py | 2 ++ skrl/trainers/jax/step.py | 2 ++ skrl/trainers/torch/parallel.py | 2 ++ skrl/trainers/torch/sequential.py | 2 ++ skrl/trainers/torch/step.py | 2 ++ 38 files changed, 76 insertions(+) diff --git a/skrl/agents/jax/a2c/a2c.py b/skrl/agents/jax/a2c/a2c.py index 16b533a6..1af1c45e 100644 --- a/skrl/agents/jax/a2c/a2c.py +++ b/skrl/agents/jax/a2c/a2c.py @@ -16,6 +16,7 @@ from skrl.resources.schedulers.jax import KLAdaptiveLR +# fmt: off # [start-config-dict-jax] A2C_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -56,6 +57,7 @@ } } # [end-config-dict-jax] +# fmt: on def compute_gae(rewards: np.ndarray, diff --git a/skrl/agents/jax/cem/cem.py b/skrl/agents/jax/cem/cem.py index f47f5b1f..3c74c634 100644 --- a/skrl/agents/jax/cem/cem.py +++ b/skrl/agents/jax/cem/cem.py @@ -15,6 +15,7 @@ from skrl.resources.optimizers.jax import Adam +# fmt: off # [start-config-dict-jax] CEM_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -47,6 +48,7 @@ } } # [end-config-dict-jax] +# fmt: on class CEM(Agent): diff --git a/skrl/agents/jax/ddpg/ddpg.py b/skrl/agents/jax/ddpg/ddpg.py index efec21d0..a00646ac 100644 --- a/skrl/agents/jax/ddpg/ddpg.py +++ b/skrl/agents/jax/ddpg/ddpg.py @@ -15,6 +15,7 @@ from skrl.resources.optimizers.jax import Adam +# fmt: off # [start-config-dict-jax] DDPG_DEFAULT_CONFIG = { "gradient_steps": 1, # gradient steps @@ -58,6 +59,7 @@ } } # [end-config-dict-jax] +# fmt: on # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function diff --git a/skrl/agents/jax/dqn/ddqn.py b/skrl/agents/jax/dqn/ddqn.py index 76868e68..6b4f2750 100644 --- a/skrl/agents/jax/dqn/ddqn.py +++ b/skrl/agents/jax/dqn/ddqn.py @@ -15,6 +15,7 @@ from skrl.resources.optimizers.jax import Adam +# fmt: off # [start-config-dict-jax] DDQN_DEFAULT_CONFIG = { "gradient_steps": 1, # gradient steps @@ -57,6 +58,7 @@ } } # [end-config-dict-jax] +# fmt: on # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function diff --git a/skrl/agents/jax/dqn/dqn.py b/skrl/agents/jax/dqn/dqn.py index 30629c09..e82fe82f 100644 --- a/skrl/agents/jax/dqn/dqn.py +++ b/skrl/agents/jax/dqn/dqn.py @@ -15,6 +15,7 @@ from skrl.resources.optimizers.jax import Adam +# fmt: off # [start-config-dict-jax] DQN_DEFAULT_CONFIG = { "gradient_steps": 1, # gradient steps @@ -57,6 +58,7 @@ } } # [end-config-dict-jax] +# fmt: on # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function diff --git a/skrl/agents/jax/ppo/ppo.py b/skrl/agents/jax/ppo/ppo.py index 7fde2472..bf52b512 100644 --- a/skrl/agents/jax/ppo/ppo.py +++ b/skrl/agents/jax/ppo/ppo.py @@ -16,6 +16,7 @@ from skrl.resources.schedulers.jax import KLAdaptiveLR +# fmt: off # [start-config-dict-jax] PPO_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -63,6 +64,7 @@ } } # [end-config-dict-jax] +# fmt: on def compute_gae(rewards: np.ndarray, diff --git a/skrl/agents/jax/rpo/rpo.py b/skrl/agents/jax/rpo/rpo.py index c0373627..98ed8a82 100644 --- a/skrl/agents/jax/rpo/rpo.py +++ b/skrl/agents/jax/rpo/rpo.py @@ -16,6 +16,7 @@ from skrl.resources.schedulers.jax import KLAdaptiveLR 
+# fmt: off # [start-config-dict-jax] RPO_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -64,6 +65,7 @@ } } # [end-config-dict-jax] +# fmt: on def compute_gae(rewards: np.ndarray, diff --git a/skrl/agents/jax/sac/sac.py b/skrl/agents/jax/sac/sac.py index 7656c26d..6bdc5e05 100644 --- a/skrl/agents/jax/sac/sac.py +++ b/skrl/agents/jax/sac/sac.py @@ -16,6 +16,7 @@ from skrl.resources.optimizers.jax import Adam +# fmt: off # [start-config-dict-jax] SAC_DEFAULT_CONFIG = { "gradient_steps": 1, # gradient steps @@ -57,6 +58,7 @@ } } # [end-config-dict-jax] +# fmt: on # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function diff --git a/skrl/agents/jax/td3/td3.py b/skrl/agents/jax/td3/td3.py index 23f4885a..b6057d32 100644 --- a/skrl/agents/jax/td3/td3.py +++ b/skrl/agents/jax/td3/td3.py @@ -15,6 +15,7 @@ from skrl.resources.optimizers.jax import Adam +# fmt: off # [start-config-dict-jax] TD3_DEFAULT_CONFIG = { "gradient_steps": 1, # gradient steps @@ -62,6 +63,7 @@ } } # [end-config-dict-jax] +# fmt: on # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function diff --git a/skrl/agents/torch/a2c/a2c.py b/skrl/agents/torch/a2c/a2c.py index 9c9cf9b6..12df988d 100644 --- a/skrl/agents/torch/a2c/a2c.py +++ b/skrl/agents/torch/a2c/a2c.py @@ -15,6 +15,7 @@ from skrl.resources.schedulers.torch import KLAdaptiveLR +# fmt: off # [start-config-dict-torch] A2C_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -55,6 +56,7 @@ } } # [end-config-dict-torch] +# fmt: on class A2C(Agent): diff --git a/skrl/agents/torch/a2c/a2c_rnn.py b/skrl/agents/torch/a2c/a2c_rnn.py index 97cc93e1..c087671a 100644 --- a/skrl/agents/torch/a2c/a2c_rnn.py +++ b/skrl/agents/torch/a2c/a2c_rnn.py @@ -15,6 +15,7 @@ from skrl.resources.schedulers.torch import KLAdaptiveLR +# fmt: off # [start-config-dict-torch] A2C_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -55,6 +56,7 @@ } } # [end-config-dict-torch] +# fmt: on class A2C_RNN(Agent): diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py index 181e5ac6..1cc857ee 100644 --- a/skrl/agents/torch/amp/amp.py +++ b/skrl/agents/torch/amp/amp.py @@ -15,6 +15,7 @@ from skrl.models.torch import Model +# fmt: off # [start-config-dict-torch] AMP_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -72,6 +73,7 @@ } } # [end-config-dict-torch] +# fmt: on class AMP(Agent): diff --git a/skrl/agents/torch/cem/cem.py b/skrl/agents/torch/cem/cem.py index 864b6b20..735250e4 100644 --- a/skrl/agents/torch/cem/cem.py +++ b/skrl/agents/torch/cem/cem.py @@ -12,6 +12,7 @@ from skrl.models.torch import Model +# fmt: off # [start-config-dict-torch] CEM_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -44,6 +45,7 @@ } } # [end-config-dict-torch] +# fmt: on class CEM(Agent): diff --git a/skrl/agents/torch/ddpg/ddpg.py b/skrl/agents/torch/ddpg/ddpg.py index 1e3d2690..16f094d8 100644 --- a/skrl/agents/torch/ddpg/ddpg.py +++ b/skrl/agents/torch/ddpg/ddpg.py @@ -13,6 +13,7 @@ from skrl.models.torch import Model +# fmt: off # [start-config-dict-torch] DDPG_DEFAULT_CONFIG = { "gradient_steps": 1, # gradient steps @@ -56,6 +57,7 @@ } } # [end-config-dict-torch] +# fmt: on class DDPG(Agent): diff --git a/skrl/agents/torch/ddpg/ddpg_rnn.py b/skrl/agents/torch/ddpg/ddpg_rnn.py index 36a98fee..6d0dd829 100644 --- a/skrl/agents/torch/ddpg/ddpg_rnn.py +++ b/skrl/agents/torch/ddpg/ddpg_rnn.py @@ -13,6 +13,7 @@ from 
skrl.models.torch import Model +# fmt: off # [start-config-dict-torch] DDPG_DEFAULT_CONFIG = { "gradient_steps": 1, # gradient steps @@ -56,6 +57,7 @@ } } # [end-config-dict-torch] +# fmt: on class DDPG_RNN(Agent): diff --git a/skrl/agents/torch/dqn/ddqn.py b/skrl/agents/torch/dqn/ddqn.py index d7e93886..c2e84d02 100644 --- a/skrl/agents/torch/dqn/ddqn.py +++ b/skrl/agents/torch/dqn/ddqn.py @@ -13,6 +13,7 @@ from skrl.models.torch import Model +# fmt: off # [start-config-dict-torch] DDQN_DEFAULT_CONFIG = { "gradient_steps": 1, # gradient steps @@ -55,6 +56,7 @@ } } # [end-config-dict-torch] +# fmt: on class DDQN(Agent): diff --git a/skrl/agents/torch/dqn/dqn.py b/skrl/agents/torch/dqn/dqn.py index 03ffa320..5fd5ba48 100644 --- a/skrl/agents/torch/dqn/dqn.py +++ b/skrl/agents/torch/dqn/dqn.py @@ -13,6 +13,7 @@ from skrl.models.torch import Model +# fmt: off # [start-config-dict-torch] DQN_DEFAULT_CONFIG = { "gradient_steps": 1, # gradient steps @@ -55,6 +56,7 @@ } } # [end-config-dict-torch] +# fmt: on class DQN(Agent): diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py index 21124fdf..26febbf2 100644 --- a/skrl/agents/torch/ppo/ppo.py +++ b/skrl/agents/torch/ppo/ppo.py @@ -15,6 +15,7 @@ from skrl.resources.schedulers.torch import KLAdaptiveLR +# fmt: off # [start-config-dict-torch] PPO_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -64,6 +65,7 @@ } } # [end-config-dict-torch] +# fmt: on class PPO(Agent): diff --git a/skrl/agents/torch/ppo/ppo_rnn.py b/skrl/agents/torch/ppo/ppo_rnn.py index 59d8cbe2..f19ca3c2 100644 --- a/skrl/agents/torch/ppo/ppo_rnn.py +++ b/skrl/agents/torch/ppo/ppo_rnn.py @@ -15,6 +15,7 @@ from skrl.resources.schedulers.torch import KLAdaptiveLR +# fmt: off # [start-config-dict-torch] PPO_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -62,6 +63,7 @@ } } # [end-config-dict-torch] +# fmt: on class PPO_RNN(Agent): diff --git a/skrl/agents/torch/q_learning/q_learning.py b/skrl/agents/torch/q_learning/q_learning.py index f9dd3442..ad247627 100644 --- a/skrl/agents/torch/q_learning/q_learning.py +++ b/skrl/agents/torch/q_learning/q_learning.py @@ -10,6 +10,7 @@ from skrl.models.torch import Model +# fmt: off # [start-config-dict-torch] Q_LEARNING_DEFAULT_CONFIG = { "discount_factor": 0.99, # discount factor (gamma) @@ -34,6 +35,7 @@ } } # [end-config-dict-torch] +# fmt: on class Q_LEARNING(Agent): diff --git a/skrl/agents/torch/rpo/rpo.py b/skrl/agents/torch/rpo/rpo.py index 1cc4a18d..25eb462b 100644 --- a/skrl/agents/torch/rpo/rpo.py +++ b/skrl/agents/torch/rpo/rpo.py @@ -15,6 +15,7 @@ from skrl.resources.schedulers.torch import KLAdaptiveLR +# fmt: off # [start-config-dict-torch] RPO_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -63,6 +64,7 @@ } } # [end-config-dict-torch] +# fmt: on class RPO(Agent): diff --git a/skrl/agents/torch/rpo/rpo_rnn.py b/skrl/agents/torch/rpo/rpo_rnn.py index 5f1ee485..c3ea7662 100644 --- a/skrl/agents/torch/rpo/rpo_rnn.py +++ b/skrl/agents/torch/rpo/rpo_rnn.py @@ -15,6 +15,7 @@ from skrl.resources.schedulers.torch import KLAdaptiveLR +# fmt: off # [start-config-dict-torch] RPO_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -63,6 +64,7 @@ } } # [end-config-dict-torch] +# fmt: on class RPO_RNN(Agent): diff --git a/skrl/agents/torch/sac/sac.py b/skrl/agents/torch/sac/sac.py index 78f4a556..4189fa85 100644 --- a/skrl/agents/torch/sac/sac.py +++ b/skrl/agents/torch/sac/sac.py @@ -15,6 +15,7 @@ from skrl.models.torch 
import Model +# fmt: off # [start-config-dict-torch] SAC_DEFAULT_CONFIG = { "gradient_steps": 1, # gradient steps @@ -56,6 +57,7 @@ } } # [end-config-dict-torch] +# fmt: on class SAC(Agent): diff --git a/skrl/agents/torch/sac/sac_rnn.py b/skrl/agents/torch/sac/sac_rnn.py index 4d8764b4..6b162958 100644 --- a/skrl/agents/torch/sac/sac_rnn.py +++ b/skrl/agents/torch/sac/sac_rnn.py @@ -15,6 +15,7 @@ from skrl.models.torch import Model +# fmt: off # [start-config-dict-torch] SAC_DEFAULT_CONFIG = { "gradient_steps": 1, # gradient steps @@ -56,6 +57,7 @@ } } # [end-config-dict-torch] +# fmt: on class SAC_RNN(Agent): diff --git a/skrl/agents/torch/sarsa/sarsa.py b/skrl/agents/torch/sarsa/sarsa.py index 9f27bc3a..3f079231 100644 --- a/skrl/agents/torch/sarsa/sarsa.py +++ b/skrl/agents/torch/sarsa/sarsa.py @@ -10,6 +10,7 @@ from skrl.models.torch import Model +# fmt: off # [start-config-dict-torch] SARSA_DEFAULT_CONFIG = { "discount_factor": 0.99, # discount factor (gamma) @@ -34,6 +35,7 @@ } } # [end-config-dict-torch] +# fmt: on class SARSA(Agent): diff --git a/skrl/agents/torch/td3/td3.py b/skrl/agents/torch/td3/td3.py index 2b791994..3bc7930b 100644 --- a/skrl/agents/torch/td3/td3.py +++ b/skrl/agents/torch/td3/td3.py @@ -14,6 +14,7 @@ from skrl.models.torch import Model +# fmt: off # [start-config-dict-torch] TD3_DEFAULT_CONFIG = { "gradient_steps": 1, # gradient steps @@ -61,6 +62,7 @@ } } # [end-config-dict-torch] +# fmt: on class TD3(Agent): diff --git a/skrl/agents/torch/td3/td3_rnn.py b/skrl/agents/torch/td3/td3_rnn.py index aeb2fd78..81b7f313 100644 --- a/skrl/agents/torch/td3/td3_rnn.py +++ b/skrl/agents/torch/td3/td3_rnn.py @@ -14,6 +14,7 @@ from skrl.models.torch import Model +# fmt: off # [start-config-dict-torch] TD3_DEFAULT_CONFIG = { "gradient_steps": 1, # gradient steps @@ -61,6 +62,7 @@ } } # [end-config-dict-torch] +# fmt: on class TD3_RNN(Agent): diff --git a/skrl/agents/torch/trpo/trpo.py b/skrl/agents/torch/trpo/trpo.py index 32a3b34c..56757867 100644 --- a/skrl/agents/torch/trpo/trpo.py +++ b/skrl/agents/torch/trpo/trpo.py @@ -14,6 +14,7 @@ from skrl.models.torch import Model +# fmt: off # [start-config-dict-torch] TRPO_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -61,6 +62,7 @@ } } # [end-config-dict-torch] +# fmt: on class TRPO(Agent): diff --git a/skrl/agents/torch/trpo/trpo_rnn.py b/skrl/agents/torch/trpo/trpo_rnn.py index 3599223f..2f6a8e61 100644 --- a/skrl/agents/torch/trpo/trpo_rnn.py +++ b/skrl/agents/torch/trpo/trpo_rnn.py @@ -14,6 +14,7 @@ from skrl.models.torch import Model +# fmt: off # [start-config-dict-torch] TRPO_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -61,6 +62,7 @@ } } # [end-config-dict-torch] +# fmt: on class TRPO_RNN(Agent): diff --git a/skrl/multi_agents/jax/ippo/ippo.py b/skrl/multi_agents/jax/ippo/ippo.py index b1f4284c..992d18e6 100644 --- a/skrl/multi_agents/jax/ippo/ippo.py +++ b/skrl/multi_agents/jax/ippo/ippo.py @@ -15,6 +15,7 @@ from skrl.resources.schedulers.jax import KLAdaptiveLR +# fmt: off # [start-config-dict-jax] IPPO_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -62,6 +63,7 @@ } } # [end-config-dict-jax] +# fmt: on def compute_gae(rewards: np.ndarray, diff --git a/skrl/multi_agents/jax/mappo/mappo.py b/skrl/multi_agents/jax/mappo/mappo.py index 108b1306..ee384e52 100644 --- a/skrl/multi_agents/jax/mappo/mappo.py +++ b/skrl/multi_agents/jax/mappo/mappo.py @@ -15,6 +15,7 @@ from skrl.resources.schedulers.jax import KLAdaptiveLR +# fmt: off # 
[start-config-dict-jax] MAPPO_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -64,6 +65,7 @@ } } # [end-config-dict-jax] +# fmt: on def compute_gae(rewards: np.ndarray, diff --git a/skrl/multi_agents/torch/ippo/ippo.py b/skrl/multi_agents/torch/ippo/ippo.py index 9ab490d7..b84c7559 100644 --- a/skrl/multi_agents/torch/ippo/ippo.py +++ b/skrl/multi_agents/torch/ippo/ippo.py @@ -15,6 +15,7 @@ from skrl.resources.schedulers.torch import KLAdaptiveLR +# fmt: off # [start-config-dict-torch] IPPO_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -62,6 +63,7 @@ } } # [end-config-dict-torch] +# fmt: on class IPPO(MultiAgent): diff --git a/skrl/multi_agents/torch/mappo/mappo.py b/skrl/multi_agents/torch/mappo/mappo.py index c985b8df..d466c7bc 100644 --- a/skrl/multi_agents/torch/mappo/mappo.py +++ b/skrl/multi_agents/torch/mappo/mappo.py @@ -15,6 +15,7 @@ from skrl.resources.schedulers.torch import KLAdaptiveLR +# fmt: off # [start-config-dict-torch] MAPPO_DEFAULT_CONFIG = { "rollouts": 16, # number of rollouts before updating @@ -64,6 +65,7 @@ } } # [end-config-dict-torch] +# fmt: on class MAPPO(MultiAgent): diff --git a/skrl/trainers/jax/sequential.py b/skrl/trainers/jax/sequential.py index 931030be..d35c3179 100644 --- a/skrl/trainers/jax/sequential.py +++ b/skrl/trainers/jax/sequential.py @@ -12,6 +12,7 @@ from skrl.trainers.jax import Trainer +# fmt: off # [start-config-dict-jax] SEQUENTIAL_TRAINER_DEFAULT_CONFIG = { "timesteps": 100000, # number of timesteps to train for @@ -21,6 +22,7 @@ "environment_info": "episode", # key used to get and log environment info } # [end-config-dict-jax] +# fmt: on class SequentialTrainer(Trainer): diff --git a/skrl/trainers/jax/step.py b/skrl/trainers/jax/step.py index 2e27b0de..e164ccde 100644 --- a/skrl/trainers/jax/step.py +++ b/skrl/trainers/jax/step.py @@ -14,6 +14,7 @@ from skrl.trainers.jax import Trainer +# fmt: off # [start-config-dict-jax] STEP_TRAINER_DEFAULT_CONFIG = { "timesteps": 100000, # number of timesteps to train for @@ -23,6 +24,7 @@ "environment_info": "episode", # key used to get and log environment info } # [end-config-dict-jax] +# fmt: on class StepTrainer(Trainer): diff --git a/skrl/trainers/torch/parallel.py b/skrl/trainers/torch/parallel.py index cc221255..a31b3a9b 100644 --- a/skrl/trainers/torch/parallel.py +++ b/skrl/trainers/torch/parallel.py @@ -12,6 +12,7 @@ from skrl.trainers.torch import Trainer +# fmt: off # [start-config-dict-torch] PARALLEL_TRAINER_DEFAULT_CONFIG = { "timesteps": 100000, # number of timesteps to train for @@ -21,6 +22,7 @@ "environment_info": "episode", # key used to get and log environment info } # [end-config-dict-torch] +# fmt: on def fn_processor(process_index, *args): diff --git a/skrl/trainers/torch/sequential.py b/skrl/trainers/torch/sequential.py index 9921faa2..c2111229 100644 --- a/skrl/trainers/torch/sequential.py +++ b/skrl/trainers/torch/sequential.py @@ -11,6 +11,7 @@ from skrl.trainers.torch import Trainer +# fmt: off # [start-config-dict-torch] SEQUENTIAL_TRAINER_DEFAULT_CONFIG = { "timesteps": 100000, # number of timesteps to train for @@ -20,6 +21,7 @@ "environment_info": "episode", # key used to get and log environment info } # [end-config-dict-torch] +# fmt: on class SequentialTrainer(Trainer): diff --git a/skrl/trainers/torch/step.py b/skrl/trainers/torch/step.py index 7405b23e..d4783c9d 100644 --- a/skrl/trainers/torch/step.py +++ b/skrl/trainers/torch/step.py @@ -11,6 +11,7 @@ from skrl.trainers.torch import Trainer +# fmt: off # 
[start-config-dict-torch] STEP_TRAINER_DEFAULT_CONFIG = { "timesteps": 100000, # number of timesteps to train for @@ -20,6 +21,7 @@ "environment_info": "episode", # key used to get and log environment info } # [end-config-dict-torch] +# fmt: on class StepTrainer(Trainer): From 362ed94e5137b2f243487ada384fc5a41b97db2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Mon, 4 Nov 2024 16:48:58 -0500 Subject: [PATCH 5/8] Apply black forma to skrl folder --- skrl/__init__.py | 55 +-- skrl/agents/jax/a2c/a2c.py | 234 ++++++++----- skrl/agents/jax/base.py | 74 ++-- skrl/agents/jax/cem/cem.py | 99 ++++-- skrl/agents/jax/ddpg/ddpg.py | 178 ++++++---- skrl/agents/jax/dqn/ddqn.py | 133 +++++--- skrl/agents/jax/dqn/dqn.py | 129 ++++--- skrl/agents/jax/ppo/ppo.py | 277 +++++++++------ skrl/agents/jax/rpo/rpo.py | 293 ++++++++++------ skrl/agents/jax/sac/sac.py | 241 ++++++++----- skrl/agents/jax/td3/td3.py | 243 +++++++++----- skrl/agents/torch/a2c/a2c.py | 139 +++++--- skrl/agents/torch/a2c/a2c_rnn.py | 203 +++++++---- skrl/agents/torch/amp/amp.py | 291 ++++++++++------ skrl/agents/torch/base.py | 96 +++--- skrl/agents/torch/cem/cem.py | 98 ++++-- skrl/agents/torch/ddpg/ddpg.py | 108 +++--- skrl/agents/torch/ddpg/ddpg_rnn.py | 130 ++++--- skrl/agents/torch/dqn/ddqn.py | 122 ++++--- skrl/agents/torch/dqn/dqn.py | 112 ++++--- skrl/agents/torch/ppo/ppo.py | 162 ++++++--- skrl/agents/torch/ppo/ppo_rnn.py | 218 ++++++++---- skrl/agents/torch/q_learning/q_learning.py | 92 +++-- skrl/agents/torch/rpo/rpo.py | 171 ++++++---- skrl/agents/torch/rpo/rpo_rnn.py | 232 +++++++++---- skrl/agents/torch/sac/sac.py | 133 +++++--- skrl/agents/torch/sac/sac_rnn.py | 157 ++++++--- skrl/agents/torch/sarsa/sarsa.py | 90 +++-- skrl/agents/torch/td3/td3.py | 135 +++++--- skrl/agents/torch/td3/td3_rnn.py | 153 ++++++--- skrl/agents/torch/trpo/trpo.py | 182 ++++++---- skrl/agents/torch/trpo/trpo_rnn.py | 239 ++++++++----- skrl/envs/jax.py | 3 +- skrl/envs/loaders/jax/__init__.py | 2 +- skrl/envs/loaders/torch/__init__.py | 2 +- skrl/envs/loaders/torch/bidexhands_envs.py | 23 +- skrl/envs/loaders/torch/isaacgym_envs.py | 128 ++++--- skrl/envs/loaders/torch/isaaclab_envs.py | 28 +- .../loaders/torch/omniverse_isaacgym_envs.py | 95 ++++-- skrl/envs/torch.py | 3 +- skrl/envs/wrappers/jax/__init__.py | 5 +- skrl/envs/wrappers/jax/base.py | 45 +-- skrl/envs/wrappers/jax/bidexhands_envs.py | 41 ++- skrl/envs/wrappers/jax/brax_envs.py | 26 +- skrl/envs/wrappers/jax/gym_envs.py | 45 ++- skrl/envs/wrappers/jax/gymnasium_envs.py | 44 ++- skrl/envs/wrappers/jax/isaacgym_envs.py | 95 +++--- skrl/envs/wrappers/jax/isaaclab_envs.py | 80 +++-- .../wrappers/jax/omniverse_isaacgym_envs.py | 42 ++- skrl/envs/wrappers/jax/pettingzoo_envs.py | 46 ++- skrl/envs/wrappers/torch/__init__.py | 5 +- skrl/envs/wrappers/torch/base.py | 30 +- skrl/envs/wrappers/torch/bidexhands_envs.py | 34 +- skrl/envs/wrappers/torch/brax_envs.py | 16 +- skrl/envs/wrappers/torch/deepmind_envs.py | 53 +-- skrl/envs/wrappers/torch/gym_envs.py | 27 +- skrl/envs/wrappers/torch/gymnasium_envs.py | 26 +- skrl/envs/wrappers/torch/isaacgym_envs.py | 37 +- skrl/envs/wrappers/torch/isaaclab_envs.py | 56 ++-- .../wrappers/torch/omniverse_isaacgym_envs.py | 10 +- skrl/envs/wrappers/torch/pettingzoo_envs.py | 52 ++- skrl/envs/wrappers/torch/robosuite_envs.py | 44 ++- skrl/memories/jax/base.py | 108 ++++-- skrl/memories/jax/random.py | 18 +- skrl/memories/torch/base.py | 86 +++-- skrl/memories/torch/random.py | 26 +- skrl/models/jax/base.py | 102 +++--- 
skrl/models/jax/categorical.py | 23 +- skrl/models/jax/deterministic.py | 10 +- skrl/models/jax/gaussian.py | 66 ++-- skrl/models/jax/multicategorical.py | 28 +- skrl/models/torch/base.py | 87 +++-- skrl/models/torch/categorical.py | 6 +- skrl/models/torch/deterministic.py | 6 +- skrl/models/torch/gaussian.py | 31 +- skrl/models/torch/multicategorical.py | 39 ++- skrl/models/torch/multivariate_gaussian.py | 20 +- skrl/models/torch/tabular.py | 21 +- skrl/multi_agents/jax/base.py | 124 ++++--- skrl/multi_agents/jax/ippo/ippo.py | 293 ++++++++++------ skrl/multi_agents/jax/mappo/mappo.py | 317 +++++++++++------- skrl/multi_agents/torch/base.py | 112 ++++--- skrl/multi_agents/torch/ippo/ippo.py | 180 ++++++---- skrl/multi_agents/torch/mappo/mappo.py | 205 +++++++---- skrl/resources/noises/jax/base.py | 4 +- .../noises/jax/ornstein_uhlenbeck.py | 16 +- skrl/resources/noises/torch/base.py | 2 +- skrl/resources/noises/torch/gaussian.py | 6 +- .../noises/torch/ornstein_uhlenbeck.py | 22 +- skrl/resources/optimizers/jax/adam.py | 6 +- .../jax/running_standard_scaler.py | 108 +++--- .../torch/running_standard_scaler.py | 44 ++- skrl/resources/schedulers/jax/kl_adaptive.py | 19 +- .../resources/schedulers/torch/kl_adaptive.py | 28 +- skrl/trainers/jax/base.py | 125 ++++--- skrl/trainers/jax/sequential.py | 76 +++-- skrl/trainers/jax/step.py | 88 +++-- skrl/trainers/torch/base.py | 125 ++++--- skrl/trainers/torch/parallel.py | 104 +++--- skrl/trainers/torch/sequential.py | 76 +++-- skrl/trainers/torch/step.py | 78 +++-- skrl/utils/__init__.py | 2 +- skrl/utils/control.py | 39 ++- skrl/utils/distributed/jax/launcher.py | 24 +- skrl/utils/huggingface.py | 7 +- skrl/utils/isaacgym_utils.py | 171 ++++++---- .../utils/model_instantiators/jax/__init__.py | 1 + .../model_instantiators/jax/categorical.py | 30 +- skrl/utils/model_instantiators/jax/common.py | 37 +- .../model_instantiators/jax/deterministic.py | 27 +- .../utils/model_instantiators/jax/gaussian.py | 44 +-- .../model_instantiators/torch/__init__.py | 1 + .../model_instantiators/torch/categorical.py | 30 +- .../utils/model_instantiators/torch/common.py | 39 ++- .../torch/deterministic.py | 27 +- .../model_instantiators/torch/gaussian.py | 44 +-- .../torch/multivariate_gaussian.py | 44 +-- .../utils/model_instantiators/torch/shared.py | 74 ++-- skrl/utils/omniverse_isaacgym_utils.py | 81 +++-- skrl/utils/postprocessing.py | 19 +- skrl/utils/runner/jax/runner.py | 55 +-- skrl/utils/runner/torch/runner.py | 55 +-- skrl/utils/spaces/jax/__init__.py | 2 +- skrl/utils/spaces/jax/spaces.py | 12 +- skrl/utils/spaces/torch/__init__.py | 2 +- skrl/utils/spaces/torch/spaces.py | 12 +- 126 files changed, 6362 insertions(+), 3714 deletions(-) diff --git a/skrl/__init__.py b/skrl/__init__.py index 7144c3eb..a41d7944 100644 --- a/skrl/__init__.py +++ b/skrl/__init__.py @@ -13,6 +13,7 @@ # read library version from metadata try: import importlib.metadata + __version__ = importlib.metadata.version("skrl") except ImportError: __version__ = "unknown" @@ -21,15 +22,18 @@ # logger with format class _Formatter(logging.Formatter): _format = "[%(name)s:%(levelname)s] %(message)s" - _formats = {logging.DEBUG: f"\x1b[38;20m{_format}\x1b[0m", - logging.INFO: f"\x1b[38;20m{_format}\x1b[0m", - logging.WARNING: f"\x1b[33;20m{_format}\x1b[0m", - logging.ERROR: f"\x1b[31;20m{_format}\x1b[0m", - logging.CRITICAL: f"\x1b[31;1m{_format}\x1b[0m"} + _formats = { + logging.DEBUG: f"\x1b[38;20m{_format}\x1b[0m", + logging.INFO: f"\x1b[38;20m{_format}\x1b[0m", + logging.WARNING: 
f"\x1b[33;20m{_format}\x1b[0m", + logging.ERROR: f"\x1b[31;20m{_format}\x1b[0m", + logging.CRITICAL: f"\x1b[31;1m{_format}\x1b[0m", + } def format(self, record): return logging.Formatter(self._formats.get(record.levelno)).format(record) + _handler = logging.StreamHandler() _handler.setLevel(logging.DEBUG) _handler.setFormatter(_Formatter()) @@ -42,13 +46,11 @@ def format(self, record): # machine learning framework configuration class _Config(object): def __init__(self) -> None: - """Machine learning framework specific configuration - """ + """Machine learning framework specific configuration""" class PyTorch(object): def __init__(self) -> None: - """PyTorch configuration - """ + """PyTorch configuration""" self._device = None # torch.distributed config self._local_rank = int(os.getenv("LOCAL_RANK", "0")) @@ -59,7 +61,10 @@ def __init__(self) -> None: # set up distributed runs if self._is_distributed: import torch - logger.info(f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})") + + logger.info( + f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})" + ) torch.distributed.init_process_group("nccl", rank=self._rank, world_size=self._world_size) torch.cuda.set_device(self._local_rank) @@ -72,6 +77,7 @@ def device(self) -> "torch.device": """ try: import torch + if self._device is None: return torch.device(f"cuda:{self._local_rank}" if torch.cuda.is_available() else "cpu") return torch.device(self._device) @@ -116,8 +122,7 @@ def is_distributed(self) -> bool: class JAX(object): def __init__(self) -> None: - """JAX configuration - """ + """JAX configuration""" self._backend = "numpy" self._key = np.array([0, 0], dtype=np.uint32) # distributed config (based on torch.distributed, since JAX doesn't implement it) @@ -126,7 +131,9 @@ def __init__(self) -> None: self._local_rank = int(os.getenv("JAX_LOCAL_RANK", "0")) self._rank = int(os.getenv("JAX_RANK", "0")) self._world_size = int(os.getenv("JAX_WORLD_SIZE", "1")) - self._coordinator_address = os.getenv("JAX_COORDINATOR_ADDR", "127.0.0.1") + ":" + os.getenv("JAX_COORDINATOR_PORT", "1234") + self._coordinator_address = ( + os.getenv("JAX_COORDINATOR_ADDR", "127.0.0.1") + ":" + os.getenv("JAX_COORDINATOR_PORT", "1234") + ) self._is_distributed = self._world_size > 1 # device self._device = f"cuda:{self._local_rank}" @@ -134,11 +141,16 @@ def __init__(self) -> None: # set up distributed runs if self._is_distributed: import jax - logger.info(f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})") - jax.distributed.initialize(coordinator_address=self._coordinator_address, - num_processes=self._world_size, - process_id=self._rank, - local_device_ids=self._local_rank) + + logger.info( + f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})" + ) + jax.distributed.initialize( + coordinator_address=self._coordinator_address, + num_processes=self._world_size, + process_id=self._rank, + local_device_ids=self._local_rank, + ) @staticmethod def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device": @@ -158,7 +170,7 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device": if isinstance(device, jax.Device): return device elif isinstance(device, str): - device_type, device_index = f"{device}:0".split(':')[:2] + device_type, device_index = f"{device}:0".split(":")[:2] try: return jax.devices(device_type)[int(device_index)] except (RuntimeError, 
IndexError) as e: @@ -196,11 +208,11 @@ def backend(self, value: str) -> None: @property def key(self) -> "jax.Array": - """Pseudo-random number generator (PRNG) key - """ + """Pseudo-random number generator (PRNG) key""" if isinstance(self._key, np.ndarray): try: import jax + with jax.default_device(self.device): self._key = jax.random.PRNGKey(self._key[1]) except ImportError: @@ -257,4 +269,5 @@ def is_distributed(self) -> bool: self.jax = JAX() self.torch = PyTorch() + config = _Config() diff --git a/skrl/agents/jax/a2c/a2c.py b/skrl/agents/jax/a2c/a2c.py index 1af1c45e..ec3dd53c 100644 --- a/skrl/agents/jax/a2c/a2c.py +++ b/skrl/agents/jax/a2c/a2c.py @@ -60,12 +60,14 @@ # fmt: on -def compute_gae(rewards: np.ndarray, - dones: np.ndarray, - values: np.ndarray, - next_values: np.ndarray, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> np.ndarray: +def compute_gae( + rewards: np.ndarray, + dones: np.ndarray, + values: np.ndarray, + next_values: np.ndarray, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, +) -> np.ndarray: """Compute the Generalized Advantage Estimator (GAE) :param rewards: Rewards obtained by the agent @@ -92,7 +94,9 @@ def compute_gae(rewards: np.ndarray, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else next_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages[i] = advantage # returns computation returns = advantages + values @@ -101,14 +105,17 @@ def compute_gae(rewards: np.ndarray, return returns, advantages + # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @jax.jit -def _compute_gae(rewards: jax.Array, - dones: jax.Array, - values: jax.Array, - next_values: jax.Array, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> jax.Array: +def _compute_gae( + rewards: jax.Array, + dones: jax.Array, + values: jax.Array, + next_values: jax.Array, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, +) -> jax.Array: advantage = 0 advantages = jnp.zeros_like(rewards) not_dones = jnp.logical_not(dones) @@ -117,7 +124,9 @@ def _compute_gae(rewards: jax.Array, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else next_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages = advantages.at[i].set(advantage) # returns computation returns = advantages + values @@ -126,18 +135,23 @@ def _compute_gae(rewards: jax.Array, return returns, advantages + @functools.partial(jax.jit, static_argnames=("policy_act", "get_entropy", "entropy_loss_scale")) -def _update_policy(policy_act, - policy_state_dict, - sampled_states, - sampled_actions, - sampled_log_prob, - sampled_advantages, - get_entropy, - entropy_loss_scale): +def _update_policy( + policy_act, + policy_state_dict, + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_advantages, + get_entropy, + entropy_loss_scale, +): # compute policy loss def _policy_loss(params): - _, next_log_prob, outputs = policy_act({"states": sampled_states, "taken_actions": 
sampled_actions}, "policy", params) + _, next_log_prob, outputs = policy_act( + {"states": sampled_states, "taken_actions": sampled_actions}, "policy", params + ) # compute approximate KL divergence ratio = next_log_prob - sampled_log_prob @@ -150,15 +164,15 @@ def _policy_loss(params): return -(sampled_advantages * next_log_prob).mean(), (entropy_loss, kl_divergence, outputs["stddev"]) - (policy_loss, (entropy_loss, kl_divergence, stddev)), grad = jax.value_and_grad(_policy_loss, has_aux=True)(policy_state_dict.params) + (policy_loss, (entropy_loss, kl_divergence, stddev)), grad = jax.value_and_grad(_policy_loss, has_aux=True)( + policy_state_dict.params + ) return grad, policy_loss, entropy_loss, kl_divergence, stddev + @functools.partial(jax.jit, static_argnames=("value_act")) -def _update_value(value_act, - value_state_dict, - sampled_states, - sampled_returns): +def _update_value(value_act, value_state_dict, sampled_states, sampled_returns): # compute value loss def _value_loss(params): predicted_values, _, _ = value_act({"states": sampled_states}, "value", params) @@ -170,13 +184,15 @@ def _value_loss(params): class A2C(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, jax.Device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, jax.Device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Advantage Actor Critic (A2C) https://arxiv.org/abs/1602.01783 @@ -202,12 +218,14 @@ def __init__(self, # _cfg = copy.deepcopy(A2C_DEFAULT_CONFIG) # TODO: TypeError: cannot pickle 'jax.Device' object _cfg = A2C_DEFAULT_CONFIG _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -257,13 +275,21 @@ def __init__(self, if self._learning_rate_scheduler is not None: if self._learning_rate_scheduler == KLAdaptiveLR: scale = False - self.scheduler = self._learning_rate_scheduler(self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = self._learning_rate_scheduler( + self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"] + ) else: - self._learning_rate = self._learning_rate_scheduler(self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"]) + self._learning_rate = self._learning_rate_scheduler( + self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"] + ) # optimizer with jax.default_device(self.device): - self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale) - self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale) + self.policy_optimizer = Adam( + model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, 
scale=scale + ) + self.value_optimizer = Adam( + model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale + ) self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer self.checkpoint_modules["value_optimizer"] = self.value_optimizer @@ -282,8 +308,7 @@ def __init__(self, self._value_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -338,16 +363,18 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in return actions, log_prob, outputs - def record_transition(self, - states: Union[np.ndarray, jax.Array], - actions: Union[np.ndarray, jax.Array], - rewards: Union[np.ndarray, jax.Array], - next_states: Union[np.ndarray, jax.Array], - terminated: Union[np.ndarray, jax.Array], - truncated: Union[np.ndarray, jax.Array], - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Union[np.ndarray, jax.Array], + actions: Union[np.ndarray, jax.Array], + rewards: Union[np.ndarray, jax.Array], + next_states: Union[np.ndarray, jax.Array], + terminated: Union[np.ndarray, jax.Array], + truncated: Union[np.ndarray, jax.Array], + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -369,7 +396,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: self._current_next_states = next_states @@ -389,11 +418,27 @@ def record_transition(self, rewards += self._discount_factor * values * truncated # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -432,7 +477,9 @@ def _update(self, timestep: int, timesteps: int) -> None: """ # compute returns and advantages self.value.training = False - last_values, _, _ = self.value.act({"states": self._state_preprocessor(self._current_next_states)}, role="value") # TODO: .float() + last_values, _, _ = self.value.act( + {"states": self._state_preprocessor(self._current_next_states)}, role="value" + ) # TODO: .float() self.value.training = True if not self._jax: # numpy backend last_values = 
jax.device_get(last_values) @@ -440,19 +487,23 @@ def _update(self, timestep: int, timesteps: int) -> None: values = self.memory.get_tensor_by_name("values") if self._jax: - returns, advantages = _compute_gae(rewards=self.memory.get_tensor_by_name("rewards"), - dones=self.memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor, - lambda_coefficient=self._lambda) + returns, advantages = _compute_gae( + rewards=self.memory.get_tensor_by_name("rewards"), + dones=self.memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor, + lambda_coefficient=self._lambda, + ) else: - returns, advantages = compute_gae(rewards=self.memory.get_tensor_by_name("rewards"), - dones=self.memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor, - lambda_coefficient=self._lambda) + returns, advantages = compute_gae( + rewards=self.memory.get_tensor_by_name("rewards"), + dones=self.memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor, + lambda_coefficient=self._lambda, + ) self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True)) self.memory.set_tensor_by_name("returns", self._value_preprocessor(returns, train=True)) @@ -473,32 +524,35 @@ def _update(self, timestep: int, timesteps: int) -> None: sampled_states = self._state_preprocessor(sampled_states, train=True) # compute policy loss - grad, policy_loss, entropy_loss, kl_divergence, stddev = _update_policy(self.policy.act, - self.policy.state_dict, - sampled_states, - sampled_actions, - sampled_log_prob, - sampled_advantages, - self.policy.get_entropy, - self._entropy_loss_scale) + grad, policy_loss, entropy_loss, kl_divergence, stddev = _update_policy( + self.policy.act, + self.policy.state_dict, + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_advantages, + self.policy.get_entropy, + self._entropy_loss_scale, + ) kl_divergences.append(kl_divergence.item()) # optimization step (policy) if config.jax.is_distributed: grad = self.policy.reduce_parameters(grad) - self.policy_optimizer = self.policy_optimizer.step(grad, self.policy, self.scheduler._lr if self.scheduler else None) + self.policy_optimizer = self.policy_optimizer.step( + grad, self.policy, self.scheduler._lr if self.scheduler else None + ) # compute value loss - grad, value_loss = _update_value(self.value.act, - self.value.state_dict, - sampled_states, - sampled_returns) + grad, value_loss = _update_value(self.value.act, self.value.state_dict, sampled_states, sampled_returns) # optimization step (value) if config.jax.is_distributed: grad = self.value.reduce_parameters(grad) - self.value_optimizer = self.value_optimizer.step(grad, self.value, self.scheduler._lr if self.scheduler else None) + self.value_optimizer = self.value_optimizer.step( + grad, self.value, self.scheduler._lr if self.scheduler else None + ) # update cumulative losses cumulative_policy_loss += policy_loss.item() @@ -512,7 +566,7 @@ def _update(self, timestep: int, timesteps: int) -> None: kl = np.mean(kl_divergences) # reduce (collect from all workers/processes) KL in distributed runs if config.jax.is_distributed: - kl = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(kl.reshape(1)).item() + kl = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(kl.reshape(1)).item() kl /= config.jax.world_size 
self.scheduler.step(kl) diff --git a/skrl/agents/jax/base.py b/skrl/agents/jax/base.py index 1d9ddc8f..d11370e7 100644 --- a/skrl/agents/jax/base.py +++ b/skrl/agents/jax/base.py @@ -17,13 +17,15 @@ class Agent: - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, jax.Device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, jax.Device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Base class that represent a RL agent :param models: Models used by the agent @@ -54,7 +56,7 @@ def __init__(self, else: self.device = device if type(device) == str: - device_type, device_index = f"{device}:0".split(':')[:2] + device_type, device_index = f"{device}:0".split(":")[:2] self.device = jax.devices(device_type)[int(device_index)] if type(memory) is list: @@ -83,7 +85,7 @@ def __init__(self, self.checkpoint_modules = {} self.checkpoint_interval = self.cfg.get("experiment", {}).get("checkpoint_interval", "auto") self.checkpoint_store_separately = self.cfg.get("experiment", {}).get("store_separately", False) - self.checkpoint_best_modules = {"timestep": 0, "reward": -2 ** 31, "saved": False, "modules": {}} + self.checkpoint_best_modules = {"timestep": 0, "reward": -(2**31), "saved": False, "modules": {}} # experiment directory directory = self.cfg.get("experiment", {}).get("directory", "") @@ -91,7 +93,9 @@ def __init__(self, if not directory: directory = os.path.join(os.getcwd(), "runs") if not experiment_name: - experiment_name = "{}_{}".format(datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S-%f"), self.__class__.__name__) + experiment_name = "{}_{}".format( + datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S-%f"), self.__class__.__name__ + ) self.experiment_dir = os.path.join(directory, experiment_name) def __str__(self) -> str: @@ -158,7 +162,7 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: models_cfg = {k: v.net._modules for (k, v) in self.models.items()} except AttributeError: models_cfg = {k: v._modules for (k, v) in self.models.items()} - wandb_config={**self.cfg, **trainer_cfg, **models_cfg} + wandb_config = {**self.cfg, **trainer_cfg, **models_cfg} # set default values wandb_kwargs = copy.deepcopy(self.cfg.get("experiment", {}).get("wandb_kwargs", {})) wandb_kwargs.setdefault("name", os.path.split(self.experiment_dir)[-1]) @@ -167,6 +171,7 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: wandb_kwargs["config"].update(wandb_config) # init Weights & Biases import wandb + wandb.init(**wandb_kwargs) # main entry to log data for consumption and visualization by TensorBoard @@ -177,6 +182,7 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: # tensorboard via torch SummaryWriter try: from torch.utils.tensorboard import SummaryWriter + self.writer = SummaryWriter(log_dir=self.experiment_dir) except ImportError as e: pass @@ -200,6 +206,7 @@ def add_scalar(self, tag, value, step): if self.writer is None: try: import tensorboardX + self.writer = 
tensorboardX.SummaryWriter(log_dir=self.experiment_dir) except ImportError as e: pass @@ -283,7 +290,9 @@ def write_checkpoint(self, timestep: int, timesteps: int) -> None: if self.checkpoint_store_separately: for name, module in self.checkpoint_modules.items(): with open(os.path.join(self.experiment_dir, "checkpoints", f"best_{name}.pickle"), "wb") as file: - pickle.dump(flax.serialization.to_bytes(self.checkpoint_best_modules["modules"][name]), file, protocol=4) + pickle.dump( + flax.serialization.to_bytes(self.checkpoint_best_modules["modules"][name]), file, protocol=4 + ) # whole agent else: modules = {} @@ -310,16 +319,18 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in """ raise NotImplementedError - def record_transition(self, - states: Union[np.ndarray, jax.Array], - actions: Union[np.ndarray, jax.Array], - rewards: Union[np.ndarray, jax.Array], - next_states: Union[np.ndarray, jax.Array], - terminated: Union[np.ndarray, jax.Array], - truncated: Union[np.ndarray, jax.Array], - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Union[np.ndarray, jax.Array], + actions: Union[np.ndarray, jax.Array], + rewards: Union[np.ndarray, jax.Array], + next_states: Union[np.ndarray, jax.Array], + terminated: Union[np.ndarray, jax.Array], + truncated: Union[np.ndarray, jax.Array], + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory (to be implemented by the inheriting classes) Inheriting classes must call this method to record episode information (rewards, timesteps, etc.). @@ -444,11 +455,13 @@ def load(self, path: str) -> None: else: logger.warning(f"Cannot load the {name} module. The agent doesn't have such an instance") - def migrate(self, - path: str, - name_map: Mapping[str, Mapping[str, str]] = {}, - auto_mapping: bool = True, - verbose: bool = False) -> bool: + def migrate( + self, + path: str, + name_map: Mapping[str, Mapping[str, str]] = {}, + auto_mapping: bool = True, + verbose: bool = False, + ) -> bool: """Migrate the specified external checkpoint to the current agent :raises NotImplementedError: Not yet implemented @@ -478,14 +491,15 @@ def post_interaction(self, timestep: int, timesteps: int) -> None: # update best models and write checkpoints if timestep > 1 and self.checkpoint_interval > 0 and not timestep % self.checkpoint_interval: # update best models - reward = np.mean(self.tracking_data.get("Reward / Total reward (mean)", -2 ** 31)) + reward = np.mean(self.tracking_data.get("Reward / Total reward (mean)", -(2**31))) if reward > self.checkpoint_best_modules["reward"]: self.checkpoint_best_modules["timestep"] = timestep self.checkpoint_best_modules["reward"] = reward self.checkpoint_best_modules["saved"] = False with jax.default_device(self.device): - self.checkpoint_best_modules["modules"] = \ - {k: copy.deepcopy(self._get_internal_value(v)) for k, v in self.checkpoint_modules.items()} + self.checkpoint_best_modules["modules"] = { + k: copy.deepcopy(self._get_internal_value(v)) for k, v in self.checkpoint_modules.items() + } # write checkpoints self.write_checkpoint(timestep, timesteps) diff --git a/skrl/agents/jax/cem/cem.py b/skrl/agents/jax/cem/cem.py index 3c74c634..c6e0c610 100644 --- a/skrl/agents/jax/cem/cem.py +++ b/skrl/agents/jax/cem/cem.py @@ -52,13 +52,15 @@ class CEM(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, 
Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, jax.Device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, jax.Device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Cross-Entropy Method (CEM) https://ieeexplore.ieee.org/abstract/document/6796865/ @@ -84,12 +86,14 @@ def __init__(self, # _cfg = copy.deepcopy(CEM_DEFAULT_CONFIG) # TODO: TypeError: cannot pickle 'jax.Device' object _cfg = CEM_DEFAULT_CONFIG _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -121,7 +125,9 @@ def __init__(self, with jax.default_device(self.device): self.optimizer = Adam(model=self.policy, lr=self._learning_rate) if self._learning_rate_scheduler is not None: - self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = self._learning_rate_scheduler( + self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["optimizer"] = self.optimizer @@ -133,8 +139,7 @@ def __init__(self, self._state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -176,16 +181,18 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in return actions, None, outputs - def record_transition(self, - states: Union[np.ndarray, jax.Array], - actions: Union[np.ndarray, jax.Array], - rewards: Union[np.ndarray, jax.Array], - next_states: Union[np.ndarray, jax.Array], - terminated: Union[np.ndarray, jax.Array], - truncated: Union[np.ndarray, jax.Array], - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Union[np.ndarray, jax.Array], + actions: Union[np.ndarray, jax.Array], + rewards: Union[np.ndarray, jax.Array], + next_states: Union[np.ndarray, jax.Array], + terminated: Union[np.ndarray, jax.Array], + truncated: Union[np.ndarray, jax.Array], + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -207,7 +214,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: # reward shaping @@ -215,11 +224,23 @@ def record_transition(self, rewards = self._rewards_shaper(rewards, timestep, timesteps) # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, 
- terminated=terminated, truncated=truncated) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) # track episodes internally if self._rollout: @@ -282,9 +303,13 @@ def _update(self, timestep: int, timesteps: int) -> None: for e in range(sampled_rewards.shape[-1]): for i, j in zip(self._episode_tracking[e][:-1], self._episode_tracking[e][1:]): limits.append([e + i, e + j]) - rewards = sampled_rewards[e + i: e + j] - returns.append(np.sum(rewards * self._discount_factor ** \ - np.flip(np.arange(rewards.shape[0]), axis=-1).reshape(rewards.shape))) + rewards = sampled_rewards[e + i : e + j] + returns.append( + np.sum( + rewards + * self._discount_factor ** np.flip(np.arange(rewards.shape[0]), axis=-1).reshape(rewards.shape) + ) + ) if not len(returns): logger.warning("No returns to update. Consider increasing the number of rollouts") @@ -295,8 +320,10 @@ def _update(self, timestep: int, timesteps: int) -> None: # get elite states and actions indexes = (returns >= return_threshold).nonzero()[0] - elite_states = np.concatenate([sampled_states[limits[i][0]:limits[i][1]] for i in indexes], axis=0) - elite_actions = np.concatenate([sampled_actions[limits[i][0]:limits[i][1]] for i in indexes], axis=0).reshape(-1) + elite_states = np.concatenate([sampled_states[limits[i][0] : limits[i][1]] for i in indexes], axis=0) + elite_actions = np.concatenate([sampled_actions[limits[i][0] : limits[i][1]] for i in indexes], axis=0).reshape( + -1 + ) # compute policy loss def _policy_loss(params): diff --git a/skrl/agents/jax/ddpg/ddpg.py b/skrl/agents/jax/ddpg/ddpg.py index a00646ac..bab901c5 100644 --- a/skrl/agents/jax/ddpg/ddpg.py +++ b/skrl/agents/jax/ddpg/ddpg.py @@ -64,23 +64,24 @@ # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @jax.jit -def _apply_exploration_noise(actions: jax.Array, - noises: jax.Array, - clip_actions_min: jax.Array, - clip_actions_max: jax.Array, - scale: float) -> jax.Array: +def _apply_exploration_noise( + actions: jax.Array, noises: jax.Array, clip_actions_min: jax.Array, clip_actions_max: jax.Array, scale: float +) -> jax.Array: noises = noises.at[:].multiply(scale) return jnp.clip(actions + noises, a_min=clip_actions_min, a_max=clip_actions_max), noises + @functools.partial(jax.jit, static_argnames=("critic_act")) -def _update_critic(critic_act, - critic_state_dict, - target_q_values: jax.Array, - sampled_states: Union[np.ndarray, jax.Array], - sampled_actions: Union[np.ndarray, jax.Array], - sampled_rewards: Union[np.ndarray, jax.Array], - sampled_dones: Union[np.ndarray, jax.Array], - discount_factor: float): +def _update_critic( + critic_act, + critic_state_dict, + target_q_values: jax.Array, + sampled_states: Union[np.ndarray, jax.Array], + sampled_actions: Union[np.ndarray, jax.Array], + sampled_rewards: Union[np.ndarray, jax.Array], + sampled_dones: Union[np.ndarray, jax.Array], + discount_factor: float, +): # compute target values target_values = sampled_rewards + discount_factor * jnp.logical_not(sampled_dones) * target_q_values @@ -94,31 +95,32 @@ def _critic_loss(params): return grad, 
critic_loss, critic_values, target_values + @functools.partial(jax.jit, static_argnames=("policy_act", "critic_act")) -def _update_policy(policy_act, - critic_act, - policy_state_dict, - critic_state_dict, - sampled_states): +def _update_policy(policy_act, critic_act, policy_state_dict, critic_state_dict, sampled_states): # compute policy (actor) loss def _policy_loss(policy_params, critic_params): actions, _, _ = policy_act({"states": sampled_states}, "policy", policy_params) critic_values, _, _ = critic_act({"states": sampled_states, "taken_actions": actions}, "critic", critic_params) return -critic_values.mean() - policy_loss, grad = jax.value_and_grad(_policy_loss, has_aux=False)(policy_state_dict.params, critic_state_dict.params) + policy_loss, grad = jax.value_and_grad(_policy_loss, has_aux=False)( + policy_state_dict.params, critic_state_dict.params + ) return grad, policy_loss class DDPG(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, jax.Device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, jax.Device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Deep Deterministic Policy Gradient (DDPG) https://arxiv.org/abs/1509.02971 @@ -144,12 +146,14 @@ def __init__(self, # _cfg = copy.deepcopy(DDPG_DEFAULT_CONFIG) # TODO: TypeError: cannot pickle 'jax.Device' object _cfg = DDPG_DEFAULT_CONFIG _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -199,11 +203,19 @@ def __init__(self, # set up optimizers and learning rate schedulers if self.policy is not None and self.critic is not None: with jax.default_device(self.device): - self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip) - self.critic_optimizer = Adam(model=self.critic, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip) + self.policy_optimizer = Adam( + model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip + ) + self.critic_optimizer = Adam( + model=self.critic, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip + ) if self._learning_rate_scheduler is not None: - self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) - self.critic_scheduler = self._learning_rate_scheduler(self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.policy_scheduler = self._learning_rate_scheduler( + self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) + self.critic_scheduler = self._learning_rate_scheduler( + self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["policy_optimizer"] = 
self.policy_optimizer self.checkpoint_modules["critic_optimizer"] = self.critic_optimizer @@ -226,8 +238,7 @@ def __init__(self, self._state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -292,13 +303,15 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in # apply exploration noise if timestep <= self._exploration_timesteps: - scale = (1 - timestep / self._exploration_timesteps) \ - * (self._exploration_initial_scale - self._exploration_final_scale) \ - + self._exploration_final_scale + scale = (1 - timestep / self._exploration_timesteps) * ( + self._exploration_initial_scale - self._exploration_final_scale + ) + self._exploration_final_scale # modify actions if self._jax: - actions, noises = _apply_exploration_noise(actions, noises, self.clip_actions_min, self.clip_actions_max, scale) + actions, noises = _apply_exploration_noise( + actions, noises, self.clip_actions_min, self.clip_actions_max, scale + ) else: noises *= scale actions = np.clip(actions + noises, a_min=self.clip_actions_min, a_max=self.clip_actions_max) @@ -316,16 +329,18 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in return actions, None, outputs - def record_transition(self, - states: Union[np.ndarray, jax.Array], - actions: Union[np.ndarray, jax.Array], - rewards: Union[np.ndarray, jax.Array], - next_states: Union[np.ndarray, jax.Array], - terminated: Union[np.ndarray, jax.Array], - truncated: Union[np.ndarray, jax.Array], - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Union[np.ndarray, jax.Array], + actions: Union[np.ndarray, jax.Array], + rewards: Union[np.ndarray, jax.Array], + next_states: Union[np.ndarray, jax.Array], + terminated: Union[np.ndarray, jax.Array], + truncated: Union[np.ndarray, jax.Array], + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -347,7 +362,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: # reward shaping @@ -355,11 +372,23 @@ def record_transition(self, rewards = self._rewards_shaper(rewards, timestep, timesteps) # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -400,8 +429,9 @@ def _update(self, 
timestep: int, timesteps: int) -> None: for gradient_step in range(self._gradient_steps): # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = self.memory.sample( + names=self._tensors_names, batch_size=self._batch_size + )[0] sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) @@ -409,17 +439,21 @@ def _update(self, timestep: int, timesteps: int) -> None: # compute target values next_actions, _, _ = self.target_policy.act({"states": sampled_next_states}, role="target_policy") - target_q_values, _, _ = self.target_critic.act({"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic") + target_q_values, _, _ = self.target_critic.act( + {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic" + ) # compute critic loss - grad, critic_loss, critic_values, target_values = _update_critic(self.critic.act, - self.critic.state_dict, - target_q_values, - sampled_states, - sampled_actions, - sampled_rewards, - sampled_dones, - self._discount_factor) + grad, critic_loss, critic_values, target_values = _update_critic( + self.critic.act, + self.critic.state_dict, + target_q_values, + sampled_states, + sampled_actions, + sampled_rewards, + sampled_dones, + self._discount_factor, + ) # optimization step (critic) if config.jax.is_distributed: @@ -427,11 +461,9 @@ def _update(self, timestep: int, timesteps: int) -> None: self.critic_optimizer = self.critic_optimizer.step(grad, self.critic) # compute policy (actor) loss - grad, policy_loss = _update_policy(self.policy.act, - self.critic.act, - self.policy.state_dict, - self.critic.state_dict, - sampled_states) + grad, policy_loss = _update_policy( + self.policy.act, self.critic.act, self.policy.state_dict, self.critic.state_dict, sampled_states + ) # optimization step (policy) if config.jax.is_distributed: diff --git a/skrl/agents/jax/dqn/ddqn.py b/skrl/agents/jax/dqn/ddqn.py index 6b4f2750..cf73b2c0 100644 --- a/skrl/agents/jax/dqn/ddqn.py +++ b/skrl/agents/jax/dqn/ddqn.py @@ -63,15 +63,17 @@ # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @functools.partial(jax.jit, static_argnames=("q_network_act")) -def _update_q_network(q_network_act, - q_network_state_dict, - next_q_values, - sampled_states, - sampled_next_states, - sampled_actions, - sampled_rewards, - sampled_dones, - discount_factor): +def _update_q_network( + q_network_act, + q_network_state_dict, + next_q_values, + sampled_states, + sampled_next_states, + sampled_actions, + sampled_rewards, + sampled_dones, + discount_factor, +): # compute target values q_values = q_network_act({"states": sampled_next_states}, "q_network")[0] actions = jnp.argmax(q_values, axis=-1, keepdims=True) @@ -90,13 +92,15 @@ def _q_network_loss(params): class DDQN(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, jax.Device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, 
Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, jax.Device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Double Deep Q-Network (DDQN) https://ojs.aaai.org/index.php/AAAI/article/view/10295 @@ -122,12 +126,14 @@ def __init__(self, # _cfg = copy.deepcopy(DDQN_DEFAULT_CONFIG) # TODO: TypeError: cannot pickle 'jax.Device' object _cfg = DDQN_DEFAULT_CONFIG _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.q_network = self.models.get("q_network", None) @@ -172,7 +178,9 @@ def __init__(self, with jax.default_device(self.device): self.optimizer = Adam(model=self.q_network, lr=self._learning_rate) if self._learning_rate_scheduler is not None: - self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = self._learning_rate_scheduler( + self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["optimizer"] = self.optimizer @@ -192,8 +200,7 @@ def __init__(self, self._state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) # create tensors in memory @@ -242,8 +249,9 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in return actions, None, outputs # sample actions with epsilon-greedy policy - epsilon = self._exploration_final_epsilon + (self._exploration_initial_epsilon - self._exploration_final_epsilon) \ - * np.exp(-1.0 * timestep / self._exploration_timesteps) + epsilon = self._exploration_final_epsilon + ( + self._exploration_initial_epsilon - self._exploration_final_epsilon + ) * np.exp(-1.0 * timestep / self._exploration_timesteps) indexes = (np.random.random(states.shape[0]) >= epsilon).nonzero()[0] if indexes.size: @@ -260,16 +268,18 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in return actions, None, outputs - def record_transition(self, - states: Union[np.ndarray, jax.Array], - actions: Union[np.ndarray, jax.Array], - rewards: Union[np.ndarray, jax.Array], - next_states: Union[np.ndarray, jax.Array], - terminated: Union[np.ndarray, jax.Array], - truncated: Union[np.ndarray, jax.Array], - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Union[np.ndarray, jax.Array], + actions: Union[np.ndarray, jax.Array], + rewards: Union[np.ndarray, jax.Array], + next_states: Union[np.ndarray, jax.Array], + terminated: Union[np.ndarray, jax.Array], + truncated: Union[np.ndarray, jax.Array], + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -291,18 +301,32 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, 
next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: # reward shaping if self._rewards_shaper is not None: rewards = self._rewards_shaper(rewards, timestep, timesteps) - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -343,8 +367,9 @@ def _update(self, timestep: int, timesteps: int) -> None: for gradient_step in range(self._gradient_steps): # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0] + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = self.memory.sample( + names=self.tensors_names, batch_size=self._batch_size + )[0] sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) @@ -352,15 +377,17 @@ def _update(self, timestep: int, timesteps: int) -> None: # compute target values next_q_values, _, _ = self.target_q_network.act({"states": sampled_next_states}, role="target_q_network") - grad, q_network_loss, target_values = _update_q_network(self.q_network.act, - self.q_network.state_dict, - next_q_values, - sampled_states, - sampled_next_states, - sampled_actions, - sampled_rewards, - sampled_dones, - self._discount_factor) + grad, q_network_loss, target_values = _update_q_network( + self.q_network.act, + self.q_network.state_dict, + next_q_values, + sampled_states, + sampled_next_states, + sampled_actions, + sampled_rewards, + sampled_dones, + self._discount_factor, + ) # optimization step (Q-network) if config.jax.is_distributed: diff --git a/skrl/agents/jax/dqn/dqn.py b/skrl/agents/jax/dqn/dqn.py index e82fe82f..d88f6135 100644 --- a/skrl/agents/jax/dqn/dqn.py +++ b/skrl/agents/jax/dqn/dqn.py @@ -63,14 +63,16 @@ # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @functools.partial(jax.jit, static_argnames=("q_network_act")) -def _update_q_network(q_network_act, - q_network_state_dict, - next_q_values, - sampled_states, - sampled_actions, - sampled_rewards, - sampled_dones, - discount_factor): +def _update_q_network( + q_network_act, + q_network_state_dict, + next_q_values, + sampled_states, + sampled_actions, + sampled_rewards, + sampled_dones, + discount_factor, +): # compute target values target_q_values = jnp.max(next_q_values, axis=-1, keepdims=True) target_values = sampled_rewards + discount_factor * jnp.logical_not(sampled_dones) * target_q_values @@ -87,13 +89,15 @@ def _q_network_loss(params): class DQN(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: 
Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, jax.Device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, jax.Device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Deep Q-Network (DQN) https://arxiv.org/abs/1312.5602 @@ -119,12 +123,14 @@ def __init__(self, # _cfg = copy.deepcopy(DQN_DEFAULT_CONFIG) # TODO: TypeError: cannot pickle 'jax.Device' object _cfg = DQN_DEFAULT_CONFIG _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.q_network = self.models.get("q_network", None) @@ -169,7 +175,9 @@ def __init__(self, with jax.default_device(self.device): self.optimizer = Adam(model=self.q_network, lr=self._learning_rate) if self._learning_rate_scheduler is not None: - self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = self._learning_rate_scheduler( + self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["optimizer"] = self.optimizer @@ -189,8 +197,7 @@ def __init__(self, self._state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) # create tensors in memory @@ -239,8 +246,9 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in return actions, None, outputs # sample actions with epsilon-greedy policy - epsilon = self._exploration_final_epsilon + (self._exploration_initial_epsilon - self._exploration_final_epsilon) \ - * np.exp(-1.0 * timestep / self._exploration_timesteps) + epsilon = self._exploration_final_epsilon + ( + self._exploration_initial_epsilon - self._exploration_final_epsilon + ) * np.exp(-1.0 * timestep / self._exploration_timesteps) indexes = (np.random.random(states.shape[0]) >= epsilon).nonzero()[0] if indexes.size: @@ -257,16 +265,18 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in return actions, None, outputs - def record_transition(self, - states: Union[np.ndarray, jax.Array], - actions: Union[np.ndarray, jax.Array], - rewards: Union[np.ndarray, jax.Array], - next_states: Union[np.ndarray, jax.Array], - terminated: Union[np.ndarray, jax.Array], - truncated: Union[np.ndarray, jax.Array], - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Union[np.ndarray, jax.Array], + actions: Union[np.ndarray, jax.Array], + rewards: Union[np.ndarray, jax.Array], + next_states: Union[np.ndarray, jax.Array], + terminated: Union[np.ndarray, jax.Array], + truncated: Union[np.ndarray, jax.Array], + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -288,18 +298,32 @@ def record_transition(self, :param timesteps: Number of timesteps 
:type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: # reward shaping if self._rewards_shaper is not None: rewards = self._rewards_shaper(rewards, timestep, timesteps) - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -340,8 +364,9 @@ def _update(self, timestep: int, timesteps: int) -> None: for gradient_step in range(self._gradient_steps): # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0] + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = self.memory.sample( + names=self.tensors_names, batch_size=self._batch_size + )[0] sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) @@ -349,14 +374,16 @@ def _update(self, timestep: int, timesteps: int) -> None: # compute target values next_q_values, _, _ = self.target_q_network.act({"states": sampled_next_states}, role="target_q_network") - grad, q_network_loss, target_values = _update_q_network(self.q_network.act, - self.q_network.state_dict, - next_q_values, - sampled_states, - sampled_actions, - sampled_rewards, - sampled_dones, - self._discount_factor) + grad, q_network_loss, target_values = _update_q_network( + self.q_network.act, + self.q_network.state_dict, + next_q_values, + sampled_states, + sampled_actions, + sampled_rewards, + sampled_dones, + self._discount_factor, + ) # optimization step (Q-network) if config.jax.is_distributed: diff --git a/skrl/agents/jax/ppo/ppo.py b/skrl/agents/jax/ppo/ppo.py index bf52b512..fd5858c1 100644 --- a/skrl/agents/jax/ppo/ppo.py +++ b/skrl/agents/jax/ppo/ppo.py @@ -67,12 +67,14 @@ # fmt: on -def compute_gae(rewards: np.ndarray, - dones: np.ndarray, - values: np.ndarray, - next_values: np.ndarray, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> np.ndarray: +def compute_gae( + rewards: np.ndarray, + dones: np.ndarray, + values: np.ndarray, + next_values: np.ndarray, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, +) -> np.ndarray: """Compute the Generalized Advantage Estimator (GAE) :param rewards: Rewards obtained by the agent @@ -99,7 +101,9 @@ def compute_gae(rewards: np.ndarray, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else next_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + 
advantage = ( + rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages[i] = advantage # returns computation returns = advantages + values @@ -108,14 +112,17 @@ def compute_gae(rewards: np.ndarray, return returns, advantages + # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @jax.jit -def _compute_gae(rewards: jax.Array, - dones: jax.Array, - values: jax.Array, - next_values: jax.Array, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> jax.Array: +def _compute_gae( + rewards: jax.Array, + dones: jax.Array, + values: jax.Array, + next_values: jax.Array, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, +) -> jax.Array: advantage = 0 advantages = jnp.zeros_like(rewards) not_dones = jnp.logical_not(dones) @@ -124,7 +131,9 @@ def _compute_gae(rewards: jax.Array, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else next_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages = advantages.at[i].set(advantage) # returns computation returns = advantages + values @@ -133,19 +142,24 @@ def _compute_gae(rewards: jax.Array, return returns, advantages + @functools.partial(jax.jit, static_argnames=("policy_act", "get_entropy", "entropy_loss_scale")) -def _update_policy(policy_act, - policy_state_dict, - sampled_states, - sampled_actions, - sampled_log_prob, - sampled_advantages, - ratio_clip, - get_entropy, - entropy_loss_scale): +def _update_policy( + policy_act, + policy_state_dict, + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_advantages, + ratio_clip, + get_entropy, + entropy_loss_scale, +): # compute policy loss def _policy_loss(params): - _, next_log_prob, outputs = policy_act({"states": sampled_states, "taken_actions": sampled_actions}, "policy", params) + _, next_log_prob, outputs = policy_act( + {"states": sampled_states, "taken_actions": sampled_actions}, "policy", params + ) # compute approximate KL divergence ratio = next_log_prob - sampled_log_prob @@ -163,19 +177,24 @@ def _policy_loss(params): return -jnp.minimum(surrogate, surrogate_clipped).mean(), (entropy_loss, kl_divergence, outputs["stddev"]) - (policy_loss, (entropy_loss, kl_divergence, stddev)), grad = jax.value_and_grad(_policy_loss, has_aux=True)(policy_state_dict.params) + (policy_loss, (entropy_loss, kl_divergence, stddev)), grad = jax.value_and_grad(_policy_loss, has_aux=True)( + policy_state_dict.params + ) return grad, policy_loss, entropy_loss, kl_divergence, stddev + @functools.partial(jax.jit, static_argnames=("value_act", "clip_predicted_values")) -def _update_value(value_act, - value_state_dict, - sampled_states, - sampled_values, - sampled_returns, - value_loss_scale, - clip_predicted_values, - value_clip): +def _update_value( + value_act, + value_state_dict, + sampled_states, + sampled_values, + sampled_returns, + value_loss_scale, + clip_predicted_values, + value_clip, +): # compute value loss def _value_loss(params): predicted_values, _, _ = value_act({"states": sampled_states}, "value", params) @@ -189,13 +208,15 @@ def _value_loss(params): class PPO(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - 
observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, jax.Device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, jax.Device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Proximal Policy Optimization (PPO) https://arxiv.org/abs/1707.06347 @@ -221,12 +242,14 @@ def __init__(self, # _cfg = copy.deepcopy(PPO_DEFAULT_CONFIG) # TODO: TypeError: cannot pickle 'jax.Device' object _cfg = PPO_DEFAULT_CONFIG _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -283,13 +306,21 @@ def __init__(self, if self._learning_rate_scheduler is not None: if self._learning_rate_scheduler == KLAdaptiveLR: scale = False - self.scheduler = self._learning_rate_scheduler(self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = self._learning_rate_scheduler( + self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"] + ) else: - self._learning_rate = self._learning_rate_scheduler(self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"]) + self._learning_rate = self._learning_rate_scheduler( + self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"] + ) # optimizer with jax.default_device(self.device): - self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale) - self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale) + self.policy_optimizer = Adam( + model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale + ) + self.value_optimizer = Adam( + model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale + ) self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer self.checkpoint_modules["value_optimizer"] = self.value_optimizer @@ -308,8 +339,7 @@ def __init__(self, self._value_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -364,16 +394,18 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in return actions, log_prob, outputs - def record_transition(self, - states: Union[np.ndarray, jax.Array], - actions: Union[np.ndarray, jax.Array], - rewards: Union[np.ndarray, jax.Array], - next_states: Union[np.ndarray, jax.Array], - terminated: Union[np.ndarray, jax.Array], - truncated: Union[np.ndarray, jax.Array], - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Union[np.ndarray, jax.Array], + actions: Union[np.ndarray, jax.Array], + rewards: Union[np.ndarray, jax.Array], + next_states: Union[np.ndarray, jax.Array], + terminated: 
Union[np.ndarray, jax.Array], + truncated: Union[np.ndarray, jax.Array], + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -395,7 +427,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: self._current_next_states = next_states @@ -415,11 +449,27 @@ def record_transition(self, rewards += self._discount_factor * values * truncated # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -458,7 +508,9 @@ def _update(self, timestep: int, timesteps: int) -> None: """ # compute returns and advantages self.value.training = False - last_values, _, _ = self.value.act({"states": self._state_preprocessor(self._current_next_states)}, role="value") # TODO: .float() + last_values, _, _ = self.value.act( + {"states": self._state_preprocessor(self._current_next_states)}, role="value" + ) # TODO: .float() self.value.training = True if not self._jax: # numpy backend last_values = jax.device_get(last_values) @@ -466,19 +518,23 @@ def _update(self, timestep: int, timesteps: int) -> None: values = self.memory.get_tensor_by_name("values") if self._jax: - returns, advantages = _compute_gae(rewards=self.memory.get_tensor_by_name("rewards"), - dones=self.memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor, - lambda_coefficient=self._lambda) + returns, advantages = _compute_gae( + rewards=self.memory.get_tensor_by_name("rewards"), + dones=self.memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor, + lambda_coefficient=self._lambda, + ) else: - returns, advantages = compute_gae(rewards=self.memory.get_tensor_by_name("rewards"), - dones=self.memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor, - lambda_coefficient=self._lambda) + returns, advantages = compute_gae( + rewards=self.memory.get_tensor_by_name("rewards"), + dones=self.memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor, + lambda_coefficient=self._lambda, + ) 
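
For orientation while reading the reformatted PPO hunks above: the backward recursion that compute_gae/_compute_gae implement can be reproduced standalone. The NumPy sketch below is illustrative only and not part of the patch; the helper name gae and the toy rollout shapes/values are assumptions.

    import numpy as np

    def gae(rewards, dones, values, next_values, discount=0.99, lam=0.95):
        # same backward recursion as the compute_gae helpers in the hunks above
        advantage = 0.0
        advantages = np.zeros_like(rewards)
        not_dones = np.logical_not(dones)
        memory_size = rewards.shape[0]
        for i in reversed(range(memory_size)):
            nv = values[i + 1] if i < memory_size - 1 else next_values
            advantage = rewards[i] - values[i] + discount * not_dones[i] * (nv + lam * advantage)
            advantages[i] = advantage
        returns = advantages + values
        # standardize advantages, as in the library helpers
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        return returns, advantages

    # toy rollout: 4 timesteps, 1 environment (shapes are illustrative)
    rewards = np.array([[1.0], [0.0], [0.0], [1.0]])
    dones = np.array([[0.0], [0.0], [0.0], [1.0]])
    values = np.array([[0.5], [0.4], [0.3], [0.2]])
    returns, advantages = gae(rewards, dones, values, next_values=np.array([0.0]))
    print(returns.squeeze())
    print(advantages.squeeze())
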
self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True)) self.memory.set_tensor_by_name("returns", self._value_preprocessor(returns, train=True)) @@ -496,20 +552,29 @@ def _update(self, timestep: int, timesteps: int) -> None: kl_divergences = [] # mini-batches loop - for sampled_states, sampled_actions, sampled_log_prob, sampled_values, sampled_returns, sampled_advantages in sampled_batches: + for ( + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_values, + sampled_returns, + sampled_advantages, + ) in sampled_batches: sampled_states = self._state_preprocessor(sampled_states, train=not epoch) # compute policy loss - grad, policy_loss, entropy_loss, kl_divergence, stddev = _update_policy(self.policy.act, - self.policy.state_dict, - sampled_states, - sampled_actions, - sampled_log_prob, - sampled_advantages, - self._ratio_clip, - self.policy.get_entropy, - self._entropy_loss_scale) + grad, policy_loss, entropy_loss, kl_divergence, stddev = _update_policy( + self.policy.act, + self.policy.state_dict, + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_advantages, + self._ratio_clip, + self.policy.get_entropy, + self._entropy_loss_scale, + ) kl_divergences.append(kl_divergence.item()) @@ -520,22 +585,28 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimization step (policy) if config.jax.is_distributed: grad = self.policy.reduce_parameters(grad) - self.policy_optimizer = self.policy_optimizer.step(grad, self.policy, self.scheduler._lr if self.scheduler else None) + self.policy_optimizer = self.policy_optimizer.step( + grad, self.policy, self.scheduler._lr if self.scheduler else None + ) # compute value loss - grad, value_loss = _update_value(self.value.act, - self.value.state_dict, - sampled_states, - sampled_values, - sampled_returns, - self._value_loss_scale, - self._clip_predicted_values, - self._value_clip) + grad, value_loss = _update_value( + self.value.act, + self.value.state_dict, + sampled_states, + sampled_values, + sampled_returns, + self._value_loss_scale, + self._clip_predicted_values, + self._value_clip, + ) # optimization step (value) if config.jax.is_distributed: grad = self.value.reduce_parameters(grad) - self.value_optimizer = self.value_optimizer.step(grad, self.value, self.scheduler._lr if self.scheduler else None) + self.value_optimizer = self.value_optimizer.step( + grad, self.value, self.scheduler._lr if self.scheduler else None + ) # update cumulative losses cumulative_policy_loss += policy_loss.item() @@ -549,7 +620,7 @@ def _update(self, timestep: int, timesteps: int) -> None: kl = np.mean(kl_divergences) # reduce (collect from all workers/processes) KL in distributed runs if config.jax.is_distributed: - kl = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(kl.reshape(1)).item() + kl = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(kl.reshape(1)).item() kl /= config.jax.world_size self.scheduler.step(kl) @@ -557,7 +628,9 @@ def _update(self, timestep: int, timesteps: int) -> None: self.track_data("Loss / Policy loss", cumulative_policy_loss / (self._learning_epochs * self._mini_batches)) self.track_data("Loss / Value loss", cumulative_value_loss / (self._learning_epochs * self._mini_batches)) if self._entropy_loss_scale: - self.track_data("Loss / Entropy loss", cumulative_entropy_loss / (self._learning_epochs * self._mini_batches)) + self.track_data( + "Loss / Entropy loss", cumulative_entropy_loss / (self._learning_epochs * self._mini_batches) + ) self.track_data("Policy 
/ Standard deviation", stddev.mean().item()) diff --git a/skrl/agents/jax/rpo/rpo.py b/skrl/agents/jax/rpo/rpo.py index 98ed8a82..d61f80a9 100644 --- a/skrl/agents/jax/rpo/rpo.py +++ b/skrl/agents/jax/rpo/rpo.py @@ -68,12 +68,14 @@ # fmt: on -def compute_gae(rewards: np.ndarray, - dones: np.ndarray, - values: np.ndarray, - next_values: np.ndarray, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> np.ndarray: +def compute_gae( + rewards: np.ndarray, + dones: np.ndarray, + values: np.ndarray, + next_values: np.ndarray, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, +) -> np.ndarray: """Compute the Generalized Advantage Estimator (GAE) :param rewards: Rewards obtained by the agent @@ -100,7 +102,9 @@ def compute_gae(rewards: np.ndarray, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else next_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages[i] = advantage # returns computation returns = advantages + values @@ -109,14 +113,17 @@ def compute_gae(rewards: np.ndarray, return returns, advantages + # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @jax.jit -def _compute_gae(rewards: jax.Array, - dones: jax.Array, - values: jax.Array, - next_values: jax.Array, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> jax.Array: +def _compute_gae( + rewards: jax.Array, + dones: jax.Array, + values: jax.Array, + next_values: jax.Array, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, +) -> jax.Array: advantage = 0 advantages = jnp.zeros_like(rewards) not_dones = jnp.logical_not(dones) @@ -125,7 +132,9 @@ def _compute_gae(rewards: jax.Array, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else next_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages = advantages.at[i].set(advantage) # returns computation returns = advantages + values @@ -134,20 +143,25 @@ def _compute_gae(rewards: jax.Array, return returns, advantages + @functools.partial(jax.jit, static_argnames=("policy_act", "get_entropy", "entropy_loss_scale")) -def _update_policy(policy_act, - policy_state_dict, - sampled_states, - sampled_actions, - sampled_log_prob, - sampled_advantages, - ratio_clip, - get_entropy, - entropy_loss_scale, - alpha): +def _update_policy( + policy_act, + policy_state_dict, + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_advantages, + ratio_clip, + get_entropy, + entropy_loss_scale, + alpha, +): # compute policy loss def _policy_loss(params): - _, next_log_prob, outputs = policy_act({"states": sampled_states, "taken_actions": sampled_actions, "alpha": alpha}, "policy", params) + _, next_log_prob, outputs = policy_act( + {"states": sampled_states, "taken_actions": sampled_actions, "alpha": alpha}, "policy", params + ) # compute approximate KL divergence ratio = next_log_prob - sampled_log_prob @@ -165,20 +179,25 @@ def _policy_loss(params): return -jnp.minimum(surrogate, surrogate_clipped).mean(), (entropy_loss, kl_divergence, 
outputs["stddev"]) - (policy_loss, (entropy_loss, kl_divergence, stddev)), grad = jax.value_and_grad(_policy_loss, has_aux=True)(policy_state_dict.params) + (policy_loss, (entropy_loss, kl_divergence, stddev)), grad = jax.value_and_grad(_policy_loss, has_aux=True)( + policy_state_dict.params + ) return grad, policy_loss, entropy_loss, kl_divergence, stddev + @functools.partial(jax.jit, static_argnames=("value_act", "clip_predicted_values")) -def _update_value(value_act, - value_state_dict, - sampled_states, - sampled_values, - sampled_returns, - value_loss_scale, - clip_predicted_values, - value_clip, - alpha): +def _update_value( + value_act, + value_state_dict, + sampled_states, + sampled_values, + sampled_returns, + value_loss_scale, + clip_predicted_values, + value_clip, + alpha, +): # compute value loss def _value_loss(params): predicted_values, _, _ = value_act({"states": sampled_states, "alpha": alpha}, "value", params) @@ -192,13 +211,15 @@ def _value_loss(params): class RPO(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, jax.Device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, jax.Device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Robust Policy Optimization (RPO) https://arxiv.org/abs/2212.07536 @@ -224,12 +245,14 @@ def __init__(self, # _cfg = copy.deepcopy(PPO_DEFAULT_CONFIG) # TODO: TypeError: cannot pickle 'jax.Device' object _cfg = RPO_DEFAULT_CONFIG _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -287,13 +310,21 @@ def __init__(self, if self._learning_rate_scheduler is not None: if self._learning_rate_scheduler == KLAdaptiveLR: scale = False - self.scheduler = self._learning_rate_scheduler(self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = self._learning_rate_scheduler( + self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"] + ) else: - self._learning_rate = self._learning_rate_scheduler(self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"]) + self._learning_rate = self._learning_rate_scheduler( + self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"] + ) # optimizer with jax.default_device(self.device): - self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale) - self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale) + self.policy_optimizer = Adam( + model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale + ) + self.value_optimizer = Adam( + model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale + ) 
self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer self.checkpoint_modules["value_optimizer"] = self.value_optimizer @@ -312,8 +343,7 @@ def __init__(self, self._value_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -359,7 +389,9 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy") # sample stochastic actions - actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states), "alpha": self._alpha}, role="policy") + actions, log_prob, outputs = self.policy.act( + {"states": self._state_preprocessor(states), "alpha": self._alpha}, role="policy" + ) if not self._jax: # numpy backend actions = jax.device_get(actions) log_prob = jax.device_get(log_prob) @@ -368,16 +400,18 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in return actions, log_prob, outputs - def record_transition(self, - states: Union[np.ndarray, jax.Array], - actions: Union[np.ndarray, jax.Array], - rewards: Union[np.ndarray, jax.Array], - next_states: Union[np.ndarray, jax.Array], - terminated: Union[np.ndarray, jax.Array], - truncated: Union[np.ndarray, jax.Array], - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Union[np.ndarray, jax.Array], + actions: Union[np.ndarray, jax.Array], + rewards: Union[np.ndarray, jax.Array], + next_states: Union[np.ndarray, jax.Array], + terminated: Union[np.ndarray, jax.Array], + truncated: Union[np.ndarray, jax.Array], + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -399,7 +433,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: self._current_next_states = next_states @@ -409,7 +445,9 @@ def record_transition(self, rewards = self._rewards_shaper(rewards, timestep, timesteps) # compute values - values, _, _ = self.value.act({"states": self._state_preprocessor(states), "alpha": self._alpha}, role="value") + values, _, _ = self.value.act( + {"states": self._state_preprocessor(states), "alpha": self._alpha}, role="value" + ) if not self._jax: # numpy backend values = jax.device_get(values) values = self._value_preprocessor(values, inverse=True) @@ -419,11 +457,27 @@ def record_transition(self, rewards += self._discount_factor * values * truncated # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - 
terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -462,7 +516,9 @@ def _update(self, timestep: int, timesteps: int) -> None: """ # compute returns and advantages self.value.training = False - last_values, _, _ = self.value.act({"states": self._state_preprocessor(self._current_next_states), "alpha": self._alpha}, role="value") # TODO: .float() + last_values, _, _ = self.value.act( + {"states": self._state_preprocessor(self._current_next_states), "alpha": self._alpha}, role="value" + ) # TODO: .float() self.value.training = True if not self._jax: # numpy backend last_values = jax.device_get(last_values) @@ -470,19 +526,23 @@ def _update(self, timestep: int, timesteps: int) -> None: values = self.memory.get_tensor_by_name("values") if self._jax: - returns, advantages = _compute_gae(rewards=self.memory.get_tensor_by_name("rewards"), - dones=self.memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor, - lambda_coefficient=self._lambda) + returns, advantages = _compute_gae( + rewards=self.memory.get_tensor_by_name("rewards"), + dones=self.memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor, + lambda_coefficient=self._lambda, + ) else: - returns, advantages = compute_gae(rewards=self.memory.get_tensor_by_name("rewards"), - dones=self.memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor, - lambda_coefficient=self._lambda) + returns, advantages = compute_gae( + rewards=self.memory.get_tensor_by_name("rewards"), + dones=self.memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor, + lambda_coefficient=self._lambda, + ) self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True)) self.memory.set_tensor_by_name("returns", self._value_preprocessor(returns, train=True)) @@ -500,21 +560,30 @@ def _update(self, timestep: int, timesteps: int) -> None: kl_divergences = [] # mini-batches loop - for sampled_states, sampled_actions, sampled_log_prob, sampled_values, sampled_returns, sampled_advantages in sampled_batches: + for ( + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_values, + sampled_returns, + sampled_advantages, + ) in sampled_batches: sampled_states = self._state_preprocessor(sampled_states, train=not epoch) # compute policy loss - grad, policy_loss, entropy_loss, kl_divergence, stddev = _update_policy(self.policy.act, - self.policy.state_dict, - sampled_states, - sampled_actions, - sampled_log_prob, - sampled_advantages, - self._ratio_clip, - self.policy.get_entropy, - self._entropy_loss_scale, - self._alpha) + grad, policy_loss, entropy_loss, kl_divergence, stddev = _update_policy( + self.policy.act, + self.policy.state_dict, + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_advantages, + self._ratio_clip, + self.policy.get_entropy, + self._entropy_loss_scale, + self._alpha, + ) kl_divergences.append(kl_divergence.item()) @@ -525,23 +594,29 @@ def _update(self, timestep: int, 
timesteps: int) -> None: # optimization step (policy) if config.jax.is_distributed: grad = self.policy.reduce_parameters(grad) - self.policy_optimizer = self.policy_optimizer.step(grad, self.policy, self.scheduler._lr if self.scheduler else None) + self.policy_optimizer = self.policy_optimizer.step( + grad, self.policy, self.scheduler._lr if self.scheduler else None + ) # compute value loss - grad, value_loss = _update_value(self.value.act, - self.value.state_dict, - sampled_states, - sampled_values, - sampled_returns, - self._value_loss_scale, - self._clip_predicted_values, - self._value_clip, - self._alpha) + grad, value_loss = _update_value( + self.value.act, + self.value.state_dict, + sampled_states, + sampled_values, + sampled_returns, + self._value_loss_scale, + self._clip_predicted_values, + self._value_clip, + self._alpha, + ) # optimization step (value) if config.jax.is_distributed: grad = self.value.reduce_parameters(grad) - self.value_optimizer = self.value_optimizer.step(grad, self.value, self.scheduler._lr if self.scheduler else None) + self.value_optimizer = self.value_optimizer.step( + grad, self.value, self.scheduler._lr if self.scheduler else None + ) # update cumulative losses cumulative_policy_loss += policy_loss.item() @@ -555,7 +630,7 @@ def _update(self, timestep: int, timesteps: int) -> None: kl = np.mean(kl_divergences) # reduce (collect from all workers/processes) KL in distributed runs if config.jax.is_distributed: - kl = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(kl.reshape(1)).item() + kl = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(kl.reshape(1)).item() kl /= config.jax.world_size self.scheduler.step(kl) @@ -563,7 +638,9 @@ def _update(self, timestep: int, timesteps: int) -> None: self.track_data("Loss / Policy loss", cumulative_policy_loss / (self._learning_epochs * self._mini_batches)) self.track_data("Loss / Value loss", cumulative_value_loss / (self._learning_epochs * self._mini_batches)) if self._entropy_loss_scale: - self.track_data("Loss / Entropy loss", cumulative_entropy_loss / (self._learning_epochs * self._mini_batches)) + self.track_data( + "Loss / Entropy loss", cumulative_entropy_loss / (self._learning_epochs * self._mini_batches) + ) self.track_data("Policy / Standard deviation", stddev.mean().item()) diff --git a/skrl/agents/jax/sac/sac.py b/skrl/agents/jax/sac/sac.py index 6bdc5e05..365aff31 100644 --- a/skrl/agents/jax/sac/sac.py +++ b/skrl/agents/jax/sac/sac.py @@ -63,19 +63,21 @@ # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @functools.partial(jax.jit, static_argnames=("critic_1_act", "critic_2_act")) -def _update_critic(critic_1_act, - critic_1_state_dict, - critic_2_act, - critic_2_state_dict, - target_q1_values: jax.Array, - target_q2_values: jax.Array, - entropy_coefficient, - next_log_prob, - sampled_states: Union[np.ndarray, jax.Array], - sampled_actions: Union[np.ndarray, jax.Array], - sampled_rewards: Union[np.ndarray, jax.Array], - sampled_dones: Union[np.ndarray, jax.Array], - discount_factor: float): +def _update_critic( + critic_1_act, + critic_1_state_dict, + critic_2_act, + critic_2_state_dict, + target_q1_values: jax.Array, + target_q2_values: jax.Array, + entropy_coefficient, + next_log_prob, + sampled_states: Union[np.ndarray, jax.Array], + sampled_actions: Union[np.ndarray, jax.Array], + sampled_rewards: Union[np.ndarray, jax.Array], + sampled_dones: Union[np.ndarray, jax.Array], + discount_factor: float, +): # compute target values target_q_values = 
jnp.minimum(target_q1_values, target_q2_values) - entropy_coefficient * next_log_prob target_values = sampled_rewards + discount_factor * jnp.logical_not(sampled_dones) * target_q_values @@ -86,31 +88,45 @@ def _critic_loss(params, critic_act, role): critic_loss = ((critic_values - target_values) ** 2).mean() return critic_loss, critic_values - (critic_1_loss, critic_1_values), grad = jax.value_and_grad(_critic_loss, has_aux=True)(critic_1_state_dict.params, critic_1_act, "critic_1") - (critic_2_loss, critic_2_values), grad = jax.value_and_grad(_critic_loss, has_aux=True)(critic_2_state_dict.params, critic_2_act, "critic_2") + (critic_1_loss, critic_1_values), grad = jax.value_and_grad(_critic_loss, has_aux=True)( + critic_1_state_dict.params, critic_1_act, "critic_1" + ) + (critic_2_loss, critic_2_values), grad = jax.value_and_grad(_critic_loss, has_aux=True)( + critic_2_state_dict.params, critic_2_act, "critic_2" + ) return grad, (critic_1_loss + critic_2_loss) / 2, critic_1_values, critic_2_values, target_values + @functools.partial(jax.jit, static_argnames=("policy_act", "critic_1_act", "critic_2_act")) -def _update_policy(policy_act, - critic_1_act, - critic_2_act, - policy_state_dict, - critic_1_state_dict, - critic_2_state_dict, - entropy_coefficient, - sampled_states): +def _update_policy( + policy_act, + critic_1_act, + critic_2_act, + policy_state_dict, + critic_1_state_dict, + critic_2_state_dict, + entropy_coefficient, + sampled_states, +): # compute policy (actor) loss def _policy_loss(policy_params, critic_1_params, critic_2_params): actions, log_prob, _ = policy_act({"states": sampled_states}, "policy", policy_params) - critic_1_values, _, _ = critic_1_act({"states": sampled_states, "taken_actions": actions}, "critic_1", critic_1_params) - critic_2_values, _, _ = critic_2_act({"states": sampled_states, "taken_actions": actions}, "critic_2", critic_2_params) + critic_1_values, _, _ = critic_1_act( + {"states": sampled_states, "taken_actions": actions}, "critic_1", critic_1_params + ) + critic_2_values, _, _ = critic_2_act( + {"states": sampled_states, "taken_actions": actions}, "critic_2", critic_2_params + ) return (entropy_coefficient * log_prob - jnp.minimum(critic_1_values, critic_2_values)).mean(), log_prob - (policy_loss, log_prob), grad = jax.value_and_grad(_policy_loss, has_aux=True)(policy_state_dict.params, critic_1_state_dict.params, critic_2_state_dict.params) + (policy_loss, log_prob), grad = jax.value_and_grad(_policy_loss, has_aux=True)( + policy_state_dict.params, critic_1_state_dict.params, critic_2_state_dict.params + ) return grad, policy_loss, log_prob + @jax.jit def _update_entropy(log_entropy_coefficient_state_dict, target_entropy, log_prob): # compute entropy loss @@ -123,13 +139,15 @@ def _entropy_loss(params): class SAC(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, jax.Device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, jax.Device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Soft Actor-Critic (SAC) 
https://arxiv.org/abs/1801.01290 @@ -155,12 +173,14 @@ def __init__(self, # _cfg = copy.deepcopy(SAC_DEFAULT_CONFIG) # TODO: TypeError: cannot pickle 'jax.Device' object _cfg = SAC_DEFAULT_CONFIG _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -225,7 +245,10 @@ class _LogEntropyCoefficient: def __init__(self, entropy_coefficient: float) -> None: class StateDict(flax.struct.PyTreeNode): params: flax.core.FrozenDict[str, Any] = flax.struct.field(pytree_node=True) - self.state_dict = StateDict(flax.core.FrozenDict({"params": jnp.array([jnp.log(entropy_coefficient)])})) + + self.state_dict = StateDict( + flax.core.FrozenDict({"params": jnp.array([jnp.log(entropy_coefficient)])}) + ) @property def value(self): @@ -240,13 +263,25 @@ def value(self): # set up optimizers and learning rate schedulers if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None: with jax.default_device(self.device): - self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip) - self.critic_1_optimizer = Adam(model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip) - self.critic_2_optimizer = Adam(model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip) + self.policy_optimizer = Adam( + model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip + ) + self.critic_1_optimizer = Adam( + model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip + ) + self.critic_2_optimizer = Adam( + model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip + ) if self._learning_rate_scheduler is not None: - self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) - self.critic_1_scheduler = self._learning_rate_scheduler(self.critic_1_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) - self.critic_2_scheduler = self._learning_rate_scheduler(self.critic_2_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.policy_scheduler = self._learning_rate_scheduler( + self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) + self.critic_1_scheduler = self._learning_rate_scheduler( + self.critic_1_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) + self.critic_2_scheduler = self._learning_rate_scheduler( + self.critic_2_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer self.checkpoint_modules["critic_1_optimizer"] = self.critic_1_optimizer @@ -270,8 +305,7 @@ def value(self): self._state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -319,16 +353,18 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in return actions, None, outputs - def record_transition(self, - states: Union[np.ndarray, jax.Array], - actions: Union[np.ndarray, jax.Array], - rewards: Union[np.ndarray, 
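The entropy coefficient is stored as log(alpha) so the optimized parameter is unconstrained while alpha itself stays positive. A minimal NumPy sketch of the standard automatic entropy tuning objective (assumed here; the exact loss body is not visible in this hunk, and the toy values are illustrative):

import numpy as np

log_alpha = np.array([np.log(0.2)])        # unconstrained parameter, alpha = exp(log_alpha) > 0
target_entropy = -3.0                      # e.g. -|A| for a 3-dimensional action space
log_prob = np.array([-2.5, -3.1, -2.9])    # log pi(a|s) for a sampled batch (toy values)

# temperature objective: L(log_alpha) = -log_alpha * (log_prob + target_entropy), averaged
entropy_loss = -(log_alpha * (log_prob + target_entropy)).mean()

# gradient w.r.t. log_alpha and a plain gradient step (the agent uses Adam)
grad = -(log_prob + target_entropy).mean()
log_alpha = log_alpha - 1e-3 * grad
alpha = np.exp(log_alpha)                  # value consumed by the policy and critic losses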
jax.Array], - next_states: Union[np.ndarray, jax.Array], - terminated: Union[np.ndarray, jax.Array], - truncated: Union[np.ndarray, jax.Array], - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Union[np.ndarray, jax.Array], + actions: Union[np.ndarray, jax.Array], + rewards: Union[np.ndarray, jax.Array], + next_states: Union[np.ndarray, jax.Array], + terminated: Union[np.ndarray, jax.Array], + truncated: Union[np.ndarray, jax.Array], + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -350,7 +386,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: # reward shaping @@ -358,11 +396,23 @@ def record_transition(self, rewards = self._rewards_shaper(rewards, timestep, timesteps) # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -403,8 +453,9 @@ def _update(self, timestep: int, timesteps: int) -> None: for gradient_step in range(self._gradient_steps): # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = self.memory.sample( + names=self._tensors_names, batch_size=self._batch_size + )[0] sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) @@ -412,23 +463,29 @@ def _update(self, timestep: int, timesteps: int) -> None: next_actions, next_log_prob, _ = self.policy.act({"states": sampled_next_states}, role="policy") # compute target values - target_q1_values, _, _ = self.target_critic_1.act({"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_1") - target_q2_values, _, _ = self.target_critic_2.act({"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_2") + target_q1_values, _, _ = self.target_critic_1.act( + {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_1" + ) + target_q2_values, _, _ = self.target_critic_2.act( + {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_2" + ) # compute critic loss - grad, critic_loss, critic_1_values, critic_2_values, target_values = _update_critic(self.critic_1.act, - 
self.critic_1.state_dict, - self.critic_2.act, - self.critic_2.state_dict, - target_q1_values, - target_q2_values, - self._entropy_coefficient, - next_log_prob, - sampled_states, - sampled_actions, - sampled_rewards, - sampled_dones, - self._discount_factor) + grad, critic_loss, critic_1_values, critic_2_values, target_values = _update_critic( + self.critic_1.act, + self.critic_1.state_dict, + self.critic_2.act, + self.critic_2.state_dict, + target_q1_values, + target_q2_values, + self._entropy_coefficient, + next_log_prob, + sampled_states, + sampled_actions, + sampled_rewards, + sampled_dones, + self._discount_factor, + ) # optimization step (critic) if config.jax.is_distributed: @@ -437,14 +494,16 @@ def _update(self, timestep: int, timesteps: int) -> None: self.critic_2_optimizer = self.critic_2_optimizer.step(grad, self.critic_2) # compute policy (actor) loss - grad, policy_loss, log_prob = _update_policy(self.policy.act, - self.critic_1.act, - self.critic_2.act, - self.policy.state_dict, - self.critic_1.state_dict, - self.critic_2.state_dict, - self._entropy_coefficient, - sampled_states) + grad, policy_loss, log_prob = _update_policy( + self.policy.act, + self.critic_1.act, + self.critic_2.act, + self.policy.state_dict, + self.critic_1.state_dict, + self.critic_2.state_dict, + self._entropy_coefficient, + sampled_states, + ) # optimization step (policy) if config.jax.is_distributed: @@ -454,9 +513,9 @@ def _update(self, timestep: int, timesteps: int) -> None: # entropy learning if self._learn_entropy: # compute entropy loss - grad, entropy_loss = _update_entropy(self.log_entropy_coefficient.state_dict, - self._target_entropy, - log_prob) + grad, entropy_loss = _update_entropy( + self.log_entropy_coefficient.state_dict, self._target_entropy, log_prob + ) # optimization step (entropy) self.entropy_optimizer = self.entropy_optimizer.step(grad, self.log_entropy_coefficient) diff --git a/skrl/agents/jax/td3/td3.py b/skrl/agents/jax/td3/td3.py index b6057d32..f784d95d 100644 --- a/skrl/agents/jax/td3/td3.py +++ b/skrl/agents/jax/td3/td3.py @@ -68,35 +68,39 @@ # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @jax.jit -def _apply_exploration_noise(actions: jax.Array, - noises: jax.Array, - clip_actions_min: jax.Array, - clip_actions_max: jax.Array, - scale: float) -> jax.Array: +def _apply_exploration_noise( + actions: jax.Array, noises: jax.Array, clip_actions_min: jax.Array, clip_actions_max: jax.Array, scale: float +) -> jax.Array: noises = noises.at[:].multiply(scale) return jnp.clip(actions + noises, a_min=clip_actions_min, a_max=clip_actions_max), noises + @jax.jit -def _apply_smooth_regularization_noise(actions: jax.Array, - noises: jax.Array, - clip_actions_min: jax.Array, - clip_actions_max: jax.Array, - smooth_regularization_clip: float) -> jax.Array: +def _apply_smooth_regularization_noise( + actions: jax.Array, + noises: jax.Array, + clip_actions_min: jax.Array, + clip_actions_max: jax.Array, + smooth_regularization_clip: float, +) -> jax.Array: noises = jnp.clip(noises, a_min=-smooth_regularization_clip, a_max=smooth_regularization_clip) return jnp.clip(actions + noises, a_min=clip_actions_min, a_max=clip_actions_max) + @functools.partial(jax.jit, static_argnames=("critic_1_act", "critic_2_act")) -def _update_critic(critic_1_act, - critic_1_state_dict, - critic_2_act, - critic_2_state_dict, - target_q1_values: jax.Array, - target_q2_values: jax.Array, - sampled_states: Union[np.ndarray, jax.Array], - sampled_actions: Union[np.ndarray, 
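Both TD3 noise helpers reduce to clipping in plain array terms: exploration noise is scaled and the resulting action is clipped to the bounds, while target policy smoothing clips the noise itself before clipping the action. A NumPy sketch with toy bounds, not the jitted skrl functions:

import numpy as np

low, high = -1.0, 1.0                      # toy action bounds
actions = np.array([0.3, -0.8, 0.95])
noises = np.array([0.2, -0.4, 0.1])

# exploration noise: scale the sampled noise, add it, clip to the action bounds
scale = 0.5
explored = np.clip(actions + scale * noises, low, high)

# target policy smoothing: clip the noise to +/- smooth_clip first, then clip the noisy action
smooth_clip = 0.25
smooth_noise = np.clip(noises, -smooth_clip, smooth_clip)
smoothed = np.clip(actions + smooth_noise, low, high)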
jax.Array], - sampled_rewards: Union[np.ndarray, jax.Array], - sampled_dones: Union[np.ndarray, jax.Array], - discount_factor: float): +def _update_critic( + critic_1_act, + critic_1_state_dict, + critic_2_act, + critic_2_state_dict, + target_q1_values: jax.Array, + target_q2_values: jax.Array, + sampled_states: Union[np.ndarray, jax.Array], + sampled_actions: Union[np.ndarray, jax.Array], + sampled_rewards: Union[np.ndarray, jax.Array], + sampled_dones: Union[np.ndarray, jax.Array], + discount_factor: float, +): # compute target values target_q_values = jnp.minimum(target_q1_values, target_q2_values) target_values = sampled_rewards + discount_factor * jnp.logical_not(sampled_dones) * target_q_values @@ -107,36 +111,43 @@ def _critic_loss(params, critic_act, role): critic_loss = ((critic_values - target_values) ** 2).mean() return critic_loss, critic_values - (critic_1_loss, critic_1_values), grad = jax.value_and_grad(_critic_loss, has_aux=True)(critic_1_state_dict.params, critic_1_act, "critic_1") - (critic_2_loss, critic_2_values), grad = jax.value_and_grad(_critic_loss, has_aux=True)(critic_2_state_dict.params, critic_2_act, "critic_2") + (critic_1_loss, critic_1_values), grad = jax.value_and_grad(_critic_loss, has_aux=True)( + critic_1_state_dict.params, critic_1_act, "critic_1" + ) + (critic_2_loss, critic_2_values), grad = jax.value_and_grad(_critic_loss, has_aux=True)( + critic_2_state_dict.params, critic_2_act, "critic_2" + ) return grad, critic_1_loss + critic_2_loss, critic_1_values, critic_2_values, target_values + @functools.partial(jax.jit, static_argnames=("policy_act", "critic_1_act")) -def _update_policy(policy_act, - critic_1_act, - policy_state_dict, - critic_1_state_dict, - sampled_states): +def _update_policy(policy_act, critic_1_act, policy_state_dict, critic_1_state_dict, sampled_states): # compute policy (actor) loss def _policy_loss(policy_params, critic_1_params): actions, _, _ = policy_act({"states": sampled_states}, "policy", policy_params) - critic_values, _, _ = critic_1_act({"states": sampled_states, "taken_actions": actions}, "critic_1", critic_1_params) + critic_values, _, _ = critic_1_act( + {"states": sampled_states, "taken_actions": actions}, "critic_1", critic_1_params + ) return -critic_values.mean() - policy_loss, grad = jax.value_and_grad(_policy_loss, has_aux=False)(policy_state_dict.params, critic_1_state_dict.params) + policy_loss, grad = jax.value_and_grad(_policy_loss, has_aux=False)( + policy_state_dict.params, critic_1_state_dict.params + ) return grad, policy_loss class TD3(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, jax.Device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, jax.Device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Twin Delayed DDPG (TD3) https://arxiv.org/abs/1802.09477 @@ -162,12 +173,14 @@ def __init__(self, # _cfg = copy.deepcopy(TD3_DEFAULT_CONFIG) # TODO: TypeError: cannot pickle 'jax.Device' object _cfg = TD3_DEFAULT_CONFIG _cfg.update(cfg if cfg is not None else {}) - 
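The TD3 targets are the SAC ones without the entropy term, and the actor simply maximizes the first critic's value of its own actions. A toy NumPy rendering of the two losses (random stand-in values, not skrl's models):

import numpy as np

rng = np.random.default_rng(1)
rewards = rng.normal(size=(4, 1))
dones = np.array([[0.0], [0.0], [1.0], [0.0]])
target_q1, target_q2 = rng.normal(size=(4, 1)), rng.normal(size=(4, 1))
gamma = 0.99

# clipped double-Q target: y = r + gamma * (1 - done) * min(Q1', Q2')
target_values = rewards + gamma * (1.0 - dones) * np.minimum(target_q1, target_q2)

# each critic regresses to y; the actor loss is -Q1(s, pi(s)) averaged over the batch
critic_values = rng.normal(size=(4, 1))
critic_loss = ((critic_values - target_values) ** 2).mean()
q1_of_policy_actions = rng.normal(size=(4, 1))
policy_loss = -q1_of_policy_actions.mean()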
super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -231,13 +244,25 @@ def __init__(self, # set up optimizers and learning rate schedulers if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None: with jax.default_device(self.device): - self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip) - self.critic_1_optimizer = Adam(model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip) - self.critic_2_optimizer = Adam(model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip) + self.policy_optimizer = Adam( + model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip + ) + self.critic_1_optimizer = Adam( + model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip + ) + self.critic_2_optimizer = Adam( + model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip + ) if self._learning_rate_scheduler is not None: - self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) - self.critic_1_scheduler = self._learning_rate_scheduler(self.critic_1_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) - self.critic_2_scheduler = self._learning_rate_scheduler(self.critic_2_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.policy_scheduler = self._learning_rate_scheduler( + self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) + self.critic_1_scheduler = self._learning_rate_scheduler( + self.critic_1_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) + self.critic_2_scheduler = self._learning_rate_scheduler( + self.critic_2_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer self.checkpoint_modules["critic_1_optimizer"] = self.critic_1_optimizer @@ -263,8 +288,7 @@ def __init__(self, self._state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -331,13 +355,15 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in # apply exploration noise if timestep <= self._exploration_timesteps: - scale = (1 - timestep / self._exploration_timesteps) \ - * (self._exploration_initial_scale - self._exploration_final_scale) \ - + self._exploration_final_scale + scale = (1 - timestep / self._exploration_timesteps) * ( + self._exploration_initial_scale - self._exploration_final_scale + ) + self._exploration_final_scale # modify actions if self._jax: - actions, noises = _apply_exploration_noise(actions, noises, self.clip_actions_min, self.clip_actions_max, scale) + actions, noises = _apply_exploration_noise( + actions, noises, self.clip_actions_min, self.clip_actions_max, scale + ) else: noises *= scale actions = np.clip(actions + noises, a_min=self.clip_actions_min, a_max=self.clip_actions_max) @@ -355,16 +381,18 @@ def act(self, states: Union[np.ndarray, jax.Array], timestep: int, timesteps: in 
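The exploration scale computed in act() is a linear interpolation from the initial to the final scale over the exploration window. Written as a plain function (the name and the clamp are illustrative, not skrl's API):

def exploration_scale(timestep, exploration_timesteps, initial_scale, final_scale):
    """Linearly anneal the exploration noise scale from initial_scale to final_scale."""
    progress = min(timestep / exploration_timesteps, 1.0)
    return (1.0 - progress) * (initial_scale - final_scale) + final_scale

# e.g. halfway through a 1000-step exploration window with scales 1.0 -> 0.1
print(exploration_scale(500, 1000, 1.0, 0.1))  # 0.55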
return actions, None, outputs - def record_transition(self, - states: Union[np.ndarray, jax.Array], - actions: Union[np.ndarray, jax.Array], - rewards: Union[np.ndarray, jax.Array], - next_states: Union[np.ndarray, jax.Array], - terminated: Union[np.ndarray, jax.Array], - truncated: Union[np.ndarray, jax.Array], - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Union[np.ndarray, jax.Array], + actions: Union[np.ndarray, jax.Array], + rewards: Union[np.ndarray, jax.Array], + next_states: Union[np.ndarray, jax.Array], + terminated: Union[np.ndarray, jax.Array], + truncated: Union[np.ndarray, jax.Array], + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -386,7 +414,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: # reward shaping @@ -394,11 +424,23 @@ def record_transition(self, rewards = self._rewards_shaper(rewards, timestep, timesteps) # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -439,8 +481,9 @@ def _update(self, timestep: int, timesteps: int) -> None: for gradient_step in range(self._gradient_steps): # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = self.memory.sample( + names=self._tensors_names, batch_size=self._batch_size + )[0] sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) @@ -450,27 +493,43 @@ def _update(self, timestep: int, timesteps: int) -> None: if self._smooth_regularization_noise is not None: noises = self._smooth_regularization_noise.sample(next_actions.shape) if self._jax: - next_actions = _apply_smooth_regularization_noise(next_actions, noises, self.clip_actions_min, self.clip_actions_max, self._smooth_regularization_clip) + next_actions = _apply_smooth_regularization_noise( + next_actions, + noises, + self.clip_actions_min, + self.clip_actions_max, + self._smooth_regularization_clip, + ) else: - noises = np.clip(noises, a_min=-self._smooth_regularization_clip, a_max=self._smooth_regularization_clip) - next_actions = np.clip(next_actions + noises, a_min=self.clip_actions_min, 
a_max=self.clip_actions_max) + noises = np.clip( + noises, a_min=-self._smooth_regularization_clip, a_max=self._smooth_regularization_clip + ) + next_actions = np.clip( + next_actions + noises, a_min=self.clip_actions_min, a_max=self.clip_actions_max + ) # compute target values - target_q1_values, _, _ = self.target_critic_1.act({"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_1") - target_q2_values, _, _ = self.target_critic_2.act({"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_2") + target_q1_values, _, _ = self.target_critic_1.act( + {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_1" + ) + target_q2_values, _, _ = self.target_critic_2.act( + {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_2" + ) # compute critic loss - grad, critic_loss, critic_1_values, critic_2_values, target_values = _update_critic(self.critic_1.act, - self.critic_1.state_dict, - self.critic_2.act, - self.critic_2.state_dict, - target_q1_values, - target_q2_values, - sampled_states, - sampled_actions, - sampled_rewards, - sampled_dones, - self._discount_factor) + grad, critic_loss, critic_1_values, critic_2_values, target_values = _update_critic( + self.critic_1.act, + self.critic_1.state_dict, + self.critic_2.act, + self.critic_2.state_dict, + target_q1_values, + target_q2_values, + sampled_states, + sampled_actions, + sampled_rewards, + sampled_dones, + self._discount_factor, + ) # optimization step (critic) if config.jax.is_distributed: @@ -483,11 +542,9 @@ def _update(self, timestep: int, timesteps: int) -> None: if not self._critic_update_counter % self._policy_delay: # compute policy (actor) loss - grad, policy_loss = _update_policy(self.policy.act, - self.critic_1.act, - self.policy.state_dict, - self.critic_1.state_dict, - sampled_states) + grad, policy_loss = _update_policy( + self.policy.act, self.critic_1.act, self.policy.state_dict, self.critic_1.state_dict, sampled_states + ) # optimization step (policy) if config.jax.is_distributed: diff --git a/skrl/agents/torch/a2c/a2c.py b/skrl/agents/torch/a2c/a2c.py index 12df988d..cb08cca8 100644 --- a/skrl/agents/torch/a2c/a2c.py +++ b/skrl/agents/torch/a2c/a2c.py @@ -60,13 +60,15 @@ class A2C(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Advantage Actor Critic (A2C) https://arxiv.org/abs/1602.01783 @@ -91,12 +93,14 @@ def __init__(self, """ _cfg = copy.deepcopy(A2C_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = 
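The `not self._critic_update_counter % self._policy_delay` check implements TD3's delayed policy updates: critics are trained every gradient step, the actor only every policy_delay-th one. A minimal sketch of that gating with illustrative counter names:

critic_update_counter = 0
policy_delay = 2

for step in range(6):
    # ... critic update happens every step ...
    critic_update_counter += 1
    if critic_update_counter % policy_delay == 0:
        # ... delayed policy update happens here ...
        print(f"step {step}: policy updated")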
self.models.get("policy", None) @@ -143,10 +147,13 @@ def __init__(self, if self.policy is self.value: self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._learning_rate) else: - self.optimizer = torch.optim.Adam(itertools.chain(self.policy.parameters(), self.value.parameters()), - lr=self._learning_rate) + self.optimizer = torch.optim.Adam( + itertools.chain(self.policy.parameters(), self.value.parameters()), lr=self._learning_rate + ) if self._learning_rate_scheduler is not None: - self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = self._learning_rate_scheduler( + self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["optimizer"] = self.optimizer @@ -164,8 +171,7 @@ def __init__(self, self._value_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -210,16 +216,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, log_prob, outputs - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -241,7 +249,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: self._current_next_states = next_states @@ -259,11 +269,27 @@ def record_transition(self, rewards += self._discount_factor * values * truncated # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -300,12 +326,15 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - def compute_gae(rewards: torch.Tensor, - dones: 
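When policy and value are separate networks, a single Adam instance can own both parameter sets by chaining the two parameter iterators, and gradient clipping later re-chains them the same way. A small self-contained PyTorch sketch with toy modules, not skrl's models:

import itertools
import torch
import torch.nn as nn

policy = nn.Linear(4, 2)
value = nn.Linear(4, 1)

# one optimizer over both modules' parameters
optimizer = torch.optim.Adam(itertools.chain(policy.parameters(), value.parameters()), lr=1e-3)

x = torch.randn(8, 4)
loss = policy(x).pow(2).mean() + value(x).pow(2).mean()

optimizer.zero_grad()
loss.backward()
# clip the gradient norm over the same combined parameter set (chain() is cheap to rebuild)
nn.utils.clip_grad_norm_(itertools.chain(policy.parameters(), value.parameters()), max_norm=0.5)
optimizer.step()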
torch.Tensor, - values: torch.Tensor, - next_values: torch.Tensor, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> torch.Tensor: + + def compute_gae( + rewards: torch.Tensor, + dones: torch.Tensor, + values: torch.Tensor, + next_values: torch.Tensor, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, + ) -> torch.Tensor: """Compute the Generalized Advantage Estimator (GAE) :param rewards: Rewards obtained by the agent @@ -332,7 +361,11 @@ def compute_gae(rewards: torch.Tensor, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else last_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] + - values[i] + + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages[i] = advantage # returns computation returns = advantages + values @@ -344,17 +377,21 @@ def compute_gae(rewards: torch.Tensor, # compute returns and advantages with torch.no_grad(): self.value.train(False) - last_values, _, _ = self.value.act({"states": self._state_preprocessor(self._current_next_states.float())}, role="value") + last_values, _, _ = self.value.act( + {"states": self._state_preprocessor(self._current_next_states.float())}, role="value" + ) self.value.train(True) last_values = self._value_preprocessor(last_values, inverse=True) values = self.memory.get_tensor_by_name("values") - returns, advantages = compute_gae(rewards=self.memory.get_tensor_by_name("rewards"), - dones=self.memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor, - lambda_coefficient=self._lambda) + returns, advantages = compute_gae( + rewards=self.memory.get_tensor_by_name("rewards"), + dones=self.memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor, + lambda_coefficient=self._lambda, + ) self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True)) self.memory.set_tensor_by_name("returns", self._value_preprocessor(returns, train=True)) @@ -374,7 +411,9 @@ def compute_gae(rewards: torch.Tensor, sampled_states = self._state_preprocessor(sampled_states, train=True) - _, next_log_prob, _ = self.policy.act({"states": sampled_states, "taken_actions": sampled_actions}, role="policy") + _, next_log_prob, _ = self.policy.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="policy" + ) # compute approximate KL divergence for KLAdaptive learning rate scheduler if self._learning_rate_scheduler: @@ -409,7 +448,9 @@ def compute_gae(rewards: torch.Tensor, if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) else: - nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip) + nn.utils.clip_grad_norm_( + itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip + ) self.optimizer.step() # update cumulative losses diff --git a/skrl/agents/torch/a2c/a2c_rnn.py b/skrl/agents/torch/a2c/a2c_rnn.py index c087671a..dd93e0dd 100644 --- a/skrl/agents/torch/a2c/a2c_rnn.py +++ b/skrl/agents/torch/a2c/a2c_rnn.py @@ -60,13 +60,15 @@ class A2C_RNN(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], 
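The GAE recursion reformatted above runs backwards over the rollout, A_t = (r_t - V_t) + gamma * (1 - done_t) * (V_{t+1} + lambda * A_{t+1}), and the value targets are recovered as returns = advantages + values. A NumPy version over a toy rollout (shapes and numbers are made up):

import numpy as np

rewards = np.array([1.0, 0.5, 0.0, 1.5])
values = np.array([0.8, 0.6, 0.4, 0.9])     # V(s_t) stored during the rollout
dones = np.array([0.0, 0.0, 1.0, 0.0])      # episode ended after the third step
last_value = 0.7                            # V(s_T) bootstrapped from the next state
gamma, lam = 0.99, 0.95

advantages = np.zeros_like(rewards)
advantage = 0.0
for i in reversed(range(len(rewards))):
    next_value = values[i + 1] if i < len(rewards) - 1 else last_value
    # A_t = (r_t - V_t) + gamma * (1 - done_t) * (V_{t+1} + lambda * A_{t+1})
    advantage = rewards[i] - values[i] + gamma * (1.0 - dones[i]) * (next_value + lam * advantage)
    advantages[i] = advantage

returns = advantages + values               # regression targets for the value function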
gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Advantage Actor Critic (A2C) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.) https://arxiv.org/abs/1602.01783 @@ -91,12 +93,14 @@ def __init__(self, """ _cfg = copy.deepcopy(A2C_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -143,10 +147,13 @@ def __init__(self, if self.policy is self.value: self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._learning_rate) else: - self.optimizer = torch.optim.Adam(itertools.chain(self.policy.parameters(), self.value.parameters()), - lr=self._learning_rate) + self.optimizer = torch.optim.Adam( + itertools.chain(self.policy.parameters(), self.value.parameters()), lr=self._learning_rate + ) if self._learning_rate_scheduler is not None: - self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = self._learning_rate_scheduler( + self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["optimizer"] = self.optimizer @@ -164,8 +171,7 @@ def __init__(self, self._value_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -194,7 +200,9 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self._rnn = True # create tensors in memory if self.memory is not None: - self.memory.create_tensor(name=f"rnn_policy_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True) + self.memory.create_tensor( + name=f"rnn_policy_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True + ) self._rnn_tensors_names.append(f"rnn_policy_{i}") # default RNN states self._rnn_initial_states["policy"].append(torch.zeros(size, dtype=torch.float32, device=self.device)) @@ -208,7 +216,9 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self._rnn = True # create tensors in memory if self.memory is not None: - self.memory.create_tensor(name=f"rnn_value_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True) + self.memory.create_tensor( + name=f"rnn_value_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True + ) self._rnn_tensors_names.append(f"rnn_value_{i}") # default RNN states self._rnn_initial_states["value"].append(torch.zeros(size, dtype=torch.float32, device=self.device)) @@ -246,16 +256,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, log_prob, outputs - def record_transition(self, - states: 
torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -277,7 +289,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: self._current_next_states = next_states @@ -298,20 +312,44 @@ def record_transition(self, # package RNN states rnn_states = {} if self._rnn: - rnn_states.update({f"rnn_policy_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["policy"])}) + rnn_states.update( + {f"rnn_policy_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["policy"])} + ) if self.policy is not self.value: - rnn_states.update({f"rnn_value_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["value"])}) + rnn_states.update( + {f"rnn_value_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["value"])} + ) # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values, **rnn_states) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + **rnn_states, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values, **rnn_states) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + **rnn_states, + ) # update RNN states if self._rnn: - self._rnn_final_states["value"] = self._rnn_final_states["policy"] if self.policy is self.value else outputs.get("rnn", []) + self._rnn_final_states["value"] = ( + self._rnn_final_states["policy"] if self.policy is self.value else outputs.get("rnn", []) + ) # reset states if the episodes have ended finished_episodes = terminated.nonzero(as_tuple=False) @@ -359,12 +397,15 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - def compute_gae(rewards: torch.Tensor, - dones: torch.Tensor, - values: torch.Tensor, - next_values: torch.Tensor, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> torch.Tensor: + + def compute_gae( + rewards: torch.Tensor, + dones: torch.Tensor, + values: torch.Tensor, + next_values: torch.Tensor, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, + ) -> torch.Tensor: """Compute the Generalized Advantage Estimator 
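The RNN bookkeeping around add_samples relies on hidden states coming out of PyTorch-style recurrent layers with shape (num_layers, num_envs, hidden_size); transposing the first two axes gives one (num_layers, hidden_size) slice per environment for storage, and the slices of environments whose episodes terminated are zeroed. A NumPy sketch of that bookkeeping with toy shapes, not skrl's memory API:

import numpy as np

num_layers, num_envs, hidden_size = 1, 3, 4
hidden = np.random.standard_normal((num_layers, num_envs, hidden_size))

# (num_layers, num_envs, hidden) -> (num_envs, num_layers, hidden): one row per environment
per_env = hidden.transpose(1, 0, 2)
assert per_env.shape == (num_envs, num_layers, hidden_size)

# reset the state of environments whose episode just terminated
terminated = np.array([False, True, False])
hidden[:, terminated, :] = 0.0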
(GAE) :param rewards: Rewards obtained by the agent @@ -391,7 +432,11 @@ def compute_gae(rewards: torch.Tensor, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else last_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] + - values[i] + + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages[i] = advantage # returns computation returns = advantages + values @@ -404,28 +449,38 @@ def compute_gae(rewards: torch.Tensor, with torch.no_grad(): self.value.train(False) rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {} - last_values, _, _ = self.value.act({"states": self._state_preprocessor(self._current_next_states.float()), **rnn}, role="value") + last_values, _, _ = self.value.act( + {"states": self._state_preprocessor(self._current_next_states.float()), **rnn}, role="value" + ) self.value.train(True) last_values = self._value_preprocessor(last_values, inverse=True) values = self.memory.get_tensor_by_name("values") - returns, advantages = compute_gae(rewards=self.memory.get_tensor_by_name("rewards"), - dones=self.memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor, - lambda_coefficient=self._lambda) + returns, advantages = compute_gae( + rewards=self.memory.get_tensor_by_name("rewards"), + dones=self.memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor, + lambda_coefficient=self._lambda, + ) self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True)) self.memory.set_tensor_by_name("returns", self._value_preprocessor(returns, train=True)) self.memory.set_tensor_by_name("advantages", advantages) # sample mini-batches from memory - sampled_batches = self.memory.sample_all(names=self._tensors_names, mini_batches=self._mini_batches, sequence_length=self._rnn_sequence_length) + sampled_batches = self.memory.sample_all( + names=self._tensors_names, mini_batches=self._mini_batches, sequence_length=self._rnn_sequence_length + ) rnn_policy, rnn_value = {}, {} if self._rnn: - sampled_rnn_batches = self.memory.sample_all(names=self._rnn_tensors_names, mini_batches=self._mini_batches, sequence_length=self._rnn_sequence_length) + sampled_rnn_batches = self.memory.sample_all( + names=self._rnn_tensors_names, + mini_batches=self._mini_batches, + sequence_length=self._rnn_sequence_length, + ) cumulative_policy_loss = 0 cumulative_entropy_loss = 0 @@ -434,19 +489,45 @@ def compute_gae(rewards: torch.Tensor, kl_divergences = [] # mini-batches loop - for i, (sampled_states, sampled_actions, sampled_dones, sampled_log_prob, sampled_returns, sampled_advantages) in enumerate(sampled_batches): + for i, ( + sampled_states, + sampled_actions, + sampled_dones, + sampled_log_prob, + sampled_returns, + sampled_advantages, + ) in enumerate(sampled_batches): if self._rnn: if self.policy is self.value: - rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn_batches[i]], "terminated": sampled_dones} + rnn_policy = { + "rnn": [s.transpose(0, 1) for s in sampled_rnn_batches[i]], + "terminated": sampled_dones, + } rnn_value = rnn_policy else: - rnn_policy = {"rnn": [s.transpose(0, 1) for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) if "policy" in n], "terminated": sampled_dones} - rnn_value = {"rnn": 
[s.transpose(0, 1) for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) if "value" in n], "terminated": sampled_dones} + rnn_policy = { + "rnn": [ + s.transpose(0, 1) + for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) + if "policy" in n + ], + "terminated": sampled_dones, + } + rnn_value = { + "rnn": [ + s.transpose(0, 1) + for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) + if "value" in n + ], + "terminated": sampled_dones, + } sampled_states = self._state_preprocessor(sampled_states, train=True) - _, next_log_prob, _ = self.policy.act({"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="policy") + _, next_log_prob, _ = self.policy.act( + {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="policy" + ) # compute approximate KL divergence for KLAdaptive learning rate scheduler if isinstance(self.scheduler, KLAdaptiveLR): @@ -480,7 +561,9 @@ def compute_gae(rewards: torch.Tensor, if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) else: - nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip) + nn.utils.clip_grad_norm_( + itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip + ) self.optimizer.step() # update cumulative losses diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py index 1cc857ee..45e40659 100644 --- a/skrl/agents/torch/amp/amp.py +++ b/skrl/agents/torch/amp/amp.py @@ -77,18 +77,20 @@ class AMP(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None, - amp_observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - motion_dataset: Optional[Memory] = None, - reply_buffer: Optional[Memory] = None, - collect_reference_motions: Optional[Callable[[int], torch.Tensor]] = None, - collect_observation: Optional[Callable[[], torch.Tensor]] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + amp_observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + motion_dataset: Optional[Memory] = None, + reply_buffer: Optional[Memory] = None, + collect_reference_motions: Optional[Callable[[int], torch.Tensor]] = None, + collect_observation: Optional[Callable[[], torch.Tensor]] = None, + ) -> None: """Adversarial Motion Priors (AMP) https://arxiv.org/abs/2104.02180 @@ -126,12 +128,14 @@ def __init__(self, """ _cfg = copy.deepcopy(AMP_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) self.amp_observation_space = amp_observation_space self.motion_dataset = motion_dataset @@ -202,12 +206,14 @@ 
def __init__(self, # set up optimizer and learning rate scheduler if self.policy is not None and self.value is not None and self.discriminator is not None: - self.optimizer = torch.optim.Adam(itertools.chain(self.policy.parameters(), - self.value.parameters(), - self.discriminator.parameters()), - lr=self._learning_rate) + self.optimizer = torch.optim.Adam( + itertools.chain(self.policy.parameters(), self.value.parameters(), self.discriminator.parameters()), + lr=self._learning_rate, + ) if self._learning_rate_scheduler is not None: - self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = self._learning_rate_scheduler( + self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["optimizer"] = self.optimizer @@ -231,8 +237,7 @@ def __init__(self, self._amp_state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -251,8 +256,19 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self.memory.create_tensor(name="amp_states", size=self.amp_observation_space, dtype=torch.float32) self.memory.create_tensor(name="next_values", size=1, dtype=torch.float32) - self.tensors_names = ["states", "actions", "rewards", "next_states", "terminated", \ - "log_prob", "values", "returns", "advantages", "amp_states", "next_values"] + self.tensors_names = [ + "states", + "actions", + "rewards", + "next_states", + "terminated", + "log_prob", + "values", + "returns", + "advantages", + "amp_states", + "next_values", + ] # create tensors for motion dataset and reply buffer if self.motion_dataset is not None: @@ -297,16 +313,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, log_prob, outputs - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -332,7 +350,9 @@ def record_transition(self, if self._current_states is not None: states = self._current_states - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: amp_states = infos["amp_obs"] @@ -352,13 +372,33 @@ def record_transition(self, with torch.no_grad(): next_values, _, _ = self.value.act({"states": self._state_preprocessor(next_states)}, role="value") next_values = self._value_preprocessor(next_values, inverse=True) - next_values *= infos['terminate'].view(-1, 1).logical_not() - - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, terminated=terminated, truncated=truncated, - log_prob=self._current_log_prob, values=values, amp_states=amp_states, next_values=next_values) + next_values 
*= infos["terminate"].view(-1, 1).logical_not() + + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + amp_states=amp_states, + next_values=next_values, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, terminated=terminated, truncated=truncated, - log_prob=self._current_log_prob, values=values, amp_states=amp_states, next_values=next_values) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + amp_states=amp_states, + next_values=next_values, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -396,12 +436,15 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - def compute_gae(rewards: torch.Tensor, - dones: torch.Tensor, - values: torch.Tensor, - next_values: torch.Tensor, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> torch.Tensor: + + def compute_gae( + rewards: torch.Tensor, + dones: torch.Tensor, + values: torch.Tensor, + next_values: torch.Tensor, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, + ) -> torch.Tensor: """Compute the Generalized Advantage Estimator (GAE) :param rewards: Rewards obtained by the agent @@ -427,7 +470,11 @@ def compute_gae(rewards: torch.Tensor, # advantages computation for i in reversed(range(memory_size)): - advantage = rewards[i] - values[i] + discount_factor * (next_values[i] + lambda_coefficient * not_dones[i] * advantage) + advantage = ( + rewards[i] + - values[i] + + discount_factor * (next_values[i] + lambda_coefficient * not_dones[i] * advantage) + ) advantages[i] = advantage # returns computation returns = advantages + values @@ -444,21 +491,27 @@ def compute_gae(rewards: torch.Tensor, amp_states = self.memory.get_tensor_by_name("amp_states") with torch.no_grad(): - amp_logits, _, _ = self.discriminator.act({"states": self._amp_state_preprocessor(amp_states)}, role="discriminator") - style_reward = -torch.log(torch.maximum(1 - 1 / (1 + torch.exp(-amp_logits)), torch.tensor(0.0001, device=self.device))) + amp_logits, _, _ = self.discriminator.act( + {"states": self._amp_state_preprocessor(amp_states)}, role="discriminator" + ) + style_reward = -torch.log( + torch.maximum(1 - 1 / (1 + torch.exp(-amp_logits)), torch.tensor(0.0001, device=self.device)) + ) style_reward *= self._discriminator_reward_scale combined_rewards = self._task_reward_weight * rewards + self._style_reward_weight * style_reward # compute returns and advantages values = self.memory.get_tensor_by_name("values") - next_values=self.memory.get_tensor_by_name("next_values") - returns, advantages = compute_gae(rewards=combined_rewards, - dones=self.memory.get_tensor_by_name("terminated"), - values=values, - next_values=next_values, - discount_factor=self._discount_factor, - lambda_coefficient=self._lambda) + next_values = self.memory.get_tensor_by_name("next_values") + returns, advantages = compute_gae( + rewards=combined_rewards, + dones=self.memory.get_tensor_by_name("terminated"), + values=values, + next_values=next_values, + discount_factor=self._discount_factor, + 
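The AMP style reward computed below turns each discriminator logit into -log(1 - sigmoid(logit)), clamped away from zero, scales it, and blends it with the task reward. In NumPy terms, with toy logits, scale, and weights:

import numpy as np

amp_logits = np.array([[2.0], [-1.0], [0.5]])       # discriminator outputs for agent transitions
task_rewards = np.array([[1.0], [0.2], [0.5]])

# style reward: -log(max(1 - sigmoid(logit), 1e-4)), scaled by the discriminator reward scale
sigmoid = 1.0 / (1.0 + np.exp(-amp_logits))
style_reward = -np.log(np.maximum(1.0 - sigmoid, 1e-4))
style_reward *= 2.0                                  # toy discriminator reward scale

# blend the task and style objectives
task_weight, style_weight = 0.5, 0.5                 # toy weights for the task/style blend
combined_rewards = task_weight * task_rewards + style_weight * style_reward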
lambda_coefficient=self._lambda, + ) self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True)) self.memory.set_tensor_by_name("returns", self._value_preprocessor(returns, train=True)) @@ -466,13 +519,15 @@ def compute_gae(rewards: torch.Tensor, # sample mini-batches from memory sampled_batches = self.memory.sample_all(names=self.tensors_names, mini_batches=self._mini_batches) - sampled_motion_batches = self.motion_dataset.sample(names=["states"], - batch_size=self.memory.memory_size * self.memory.num_envs, - mini_batches=self._mini_batches) + sampled_motion_batches = self.motion_dataset.sample( + names=["states"], batch_size=self.memory.memory_size * self.memory.num_envs, mini_batches=self._mini_batches + ) if len(self.reply_buffer): - sampled_replay_batches = self.reply_buffer.sample(names=["states"], - batch_size=self.memory.memory_size * self.memory.num_envs, - mini_batches=self._mini_batches) + sampled_replay_batches = self.reply_buffer.sample( + names=["states"], + batch_size=self.memory.memory_size * self.memory.num_envs, + mini_batches=self._mini_batches, + ) else: sampled_replay_batches = [[batches[self.tensors_names.index("amp_states")]] for batches in sampled_batches] @@ -485,13 +540,25 @@ def compute_gae(rewards: torch.Tensor, for epoch in range(self._learning_epochs): # mini-batches loop - for batch_index, (sampled_states, sampled_actions, _, _, _, \ - sampled_log_prob, sampled_values, sampled_returns, sampled_advantages, \ - sampled_amp_states, _) in enumerate(sampled_batches): + for batch_index, ( + sampled_states, + sampled_actions, + _, + _, + _, + sampled_log_prob, + sampled_values, + sampled_returns, + sampled_advantages, + sampled_amp_states, + _, + ) in enumerate(sampled_batches): sampled_states = self._state_preprocessor(sampled_states, train=True) - _, next_log_prob, _ = self.policy.act({"states": sampled_states, "taken_actions": sampled_actions}, role="policy") + _, next_log_prob, _ = self.policy.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="policy" + ) # compute entropy loss if self._entropy_loss_scale: @@ -502,7 +569,9 @@ def compute_gae(rewards: torch.Tensor, # compute policy loss ratio = torch.exp(next_log_prob - sampled_log_prob) surrogate = sampled_advantages * ratio - surrogate_clipped = sampled_advantages * torch.clip(ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip) + surrogate_clipped = sampled_advantages * torch.clip( + ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip + ) policy_loss = -torch.min(surrogate, surrogate_clipped).mean() @@ -510,54 +579,75 @@ def compute_gae(rewards: torch.Tensor, predicted_values, _, _ = self.value.act({"states": sampled_states}, role="value") if self._clip_predicted_values: - predicted_values = sampled_values + torch.clip(predicted_values - sampled_values, - min=-self._value_clip, - max=self._value_clip) + predicted_values = sampled_values + torch.clip( + predicted_values - sampled_values, min=-self._value_clip, max=self._value_clip + ) value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values) # compute discriminator loss if self._discriminator_batch_size: - sampled_amp_states = self._amp_state_preprocessor(sampled_amp_states[0:self._discriminator_batch_size], train=True) + sampled_amp_states = self._amp_state_preprocessor( + sampled_amp_states[0 : self._discriminator_batch_size], train=True + ) sampled_amp_replay_states = self._amp_state_preprocessor( - sampled_replay_batches[batch_index][0][0:self._discriminator_batch_size], 
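AMP optimizes the policy with the PPO-style clipped surrogate and clipped value update that appear in the mini-batch loop below. A compact NumPy rendering of the same computation over a toy batch:

import numpy as np

new_log_prob = np.array([-1.1, -0.9, -2.0])
old_log_prob = np.array([-1.0, -1.0, -1.5])
advantages = np.array([0.5, -0.2, 1.0])
ratio_clip = 0.2

ratio = np.exp(new_log_prob - old_log_prob)
surrogate = advantages * ratio
surrogate_clipped = advantages * np.clip(ratio, 1.0 - ratio_clip, 1.0 + ratio_clip)
policy_loss = -np.minimum(surrogate, surrogate_clipped).mean()

# clipped value update: keep the new prediction within +/- value_clip of the stored one
old_values, new_values, returns, value_clip = np.array([1.0]), np.array([1.4]), np.array([1.2]), 0.2
clipped_values = old_values + np.clip(new_values - old_values, -value_clip, value_clip)
value_loss = ((returns - clipped_values) ** 2).mean()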
train=True) + sampled_replay_batches[batch_index][0][0 : self._discriminator_batch_size], train=True + ) sampled_amp_motion_states = self._amp_state_preprocessor( - sampled_motion_batches[batch_index][0][0:self._discriminator_batch_size], train=True) + sampled_motion_batches[batch_index][0][0 : self._discriminator_batch_size], train=True + ) else: sampled_amp_states = self._amp_state_preprocessor(sampled_amp_states, train=True) - sampled_amp_replay_states = self._amp_state_preprocessor(sampled_replay_batches[batch_index][0], train=True) - sampled_amp_motion_states = self._amp_state_preprocessor(sampled_motion_batches[batch_index][0], train=True) + sampled_amp_replay_states = self._amp_state_preprocessor( + sampled_replay_batches[batch_index][0], train=True + ) + sampled_amp_motion_states = self._amp_state_preprocessor( + sampled_motion_batches[batch_index][0], train=True + ) sampled_amp_motion_states.requires_grad_(True) amp_logits, _, _ = self.discriminator.act({"states": sampled_amp_states}, role="discriminator") - amp_replay_logits, _, _ = self.discriminator.act({"states": sampled_amp_replay_states}, role="discriminator") - amp_motion_logits, _, _ = self.discriminator.act({"states": sampled_amp_motion_states}, role="discriminator") + amp_replay_logits, _, _ = self.discriminator.act( + {"states": sampled_amp_replay_states}, role="discriminator" + ) + amp_motion_logits, _, _ = self.discriminator.act( + {"states": sampled_amp_motion_states}, role="discriminator" + ) amp_cat_logits = torch.cat([amp_logits, amp_replay_logits], dim=0) # discriminator prediction loss - discriminator_loss = 0.5 * (nn.BCEWithLogitsLoss()(amp_cat_logits, torch.zeros_like(amp_cat_logits)) \ - + torch.nn.BCEWithLogitsLoss()(amp_motion_logits, torch.ones_like(amp_motion_logits))) + discriminator_loss = 0.5 * ( + nn.BCEWithLogitsLoss()(amp_cat_logits, torch.zeros_like(amp_cat_logits)) + + torch.nn.BCEWithLogitsLoss()(amp_motion_logits, torch.ones_like(amp_motion_logits)) + ) # discriminator logit regularization if self._discriminator_logit_regularization_scale: logit_weights = torch.flatten(list(self.discriminator.modules())[-1].weight) - discriminator_loss += self._discriminator_logit_regularization_scale * torch.sum(torch.square(logit_weights)) + discriminator_loss += self._discriminator_logit_regularization_scale * torch.sum( + torch.square(logit_weights) + ) # discriminator gradient penalty if self._discriminator_gradient_penalty_scale: - amp_motion_gradient = torch.autograd.grad(amp_motion_logits, - sampled_amp_motion_states, - grad_outputs=torch.ones_like(amp_motion_logits), - create_graph=True, - retain_graph=True, - only_inputs=True) + amp_motion_gradient = torch.autograd.grad( + amp_motion_logits, + sampled_amp_motion_states, + grad_outputs=torch.ones_like(amp_motion_logits), + create_graph=True, + retain_graph=True, + only_inputs=True, + ) gradient_penalty = torch.sum(torch.square(amp_motion_gradient[0]), dim=-1).mean() discriminator_loss += self._discriminator_gradient_penalty_scale * gradient_penalty # discriminator weight decay if self._discriminator_weight_decay_scale: - weights = [torch.flatten(module.weight) for module in self.discriminator.modules() \ - if isinstance(module, torch.nn.Linear)] + weights = [ + torch.flatten(module.weight) + for module in self.discriminator.modules() + if isinstance(module, torch.nn.Linear) + ] weight_decay = torch.sum(torch.square(torch.cat(weights, dim=-1))) discriminator_loss += self._discriminator_weight_decay_scale * weight_decay @@ -571,9 +661,12 @@ def 
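The core of the discriminator objective below is binary cross-entropy on raw logits: agent and replay samples are pushed toward label 0, reference-motion samples toward label 1, and the logit regularization, gradient penalty, and weight decay terms are added on top. A NumPy sketch of the two BCE terms (toy logits; the penalty terms need autograd and are omitted):

import numpy as np

def bce_with_logits(logits, targets):
    # numerically stable binary cross-entropy on raw logits, averaged over the batch
    return np.mean(np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits))))

agent_and_replay_logits = np.array([0.8, -0.3, 1.2, 0.1])   # should be classified as "fake" (0)
motion_logits = np.array([1.5, 0.9, 2.1])                   # reference motions should be "real" (1)

discriminator_loss = 0.5 * (
    bce_with_logits(agent_and_replay_logits, np.zeros_like(agent_and_replay_logits))
    + bce_with_logits(motion_logits, np.ones_like(motion_logits))
)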
compute_gae(rewards: torch.Tensor, self.value.reduce_parameters() self.discriminator.reduce_parameters() if self._grad_norm_clip > 0: - nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), - self.value.parameters(), - self.discriminator.parameters()), self._grad_norm_clip) + nn.utils.clip_grad_norm_( + itertools.chain( + self.policy.parameters(), self.value.parameters(), self.discriminator.parameters() + ), + self._grad_norm_clip, + ) self.optimizer.step() # update cumulative losses @@ -594,8 +687,12 @@ def compute_gae(rewards: torch.Tensor, self.track_data("Loss / Policy loss", cumulative_policy_loss / (self._learning_epochs * self._mini_batches)) self.track_data("Loss / Value loss", cumulative_value_loss / (self._learning_epochs * self._mini_batches)) if self._entropy_loss_scale: - self.track_data("Loss / Entropy loss", cumulative_entropy_loss / (self._learning_epochs * self._mini_batches)) - self.track_data("Loss / Discriminator loss", cumulative_discriminator_loss / (self._learning_epochs * self._mini_batches)) + self.track_data( + "Loss / Entropy loss", cumulative_entropy_loss / (self._learning_epochs * self._mini_batches) + ) + self.track_data( + "Loss / Discriminator loss", cumulative_discriminator_loss / (self._learning_epochs * self._mini_batches) + ) self.track_data("Policy / Standard deviation", self.policy.distribution(role="policy").stddev.mean().item()) diff --git a/skrl/agents/torch/base.py b/skrl/agents/torch/base.py index d7c29e19..4cd6244f 100644 --- a/skrl/agents/torch/base.py +++ b/skrl/agents/torch/base.py @@ -17,13 +17,15 @@ class Agent: - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Base class that represent a RL agent :param models: Models used by the agent @@ -46,7 +48,9 @@ def __init__(self, self.observation_space = observation_space self.action_space = action_space self.cfg = cfg if cfg is not None else {} - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device) + self.device = ( + torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device) + ) if type(memory) is list: self.memory = memory[0] @@ -74,7 +78,7 @@ def __init__(self, self.checkpoint_modules = {} self.checkpoint_interval = self.cfg.get("experiment", {}).get("checkpoint_interval", "auto") self.checkpoint_store_separately = self.cfg.get("experiment", {}).get("store_separately", False) - self.checkpoint_best_modules = {"timestep": 0, "reward": -2 ** 31, "saved": False, "modules": {}} + self.checkpoint_best_modules = {"timestep": 0, "reward": -(2**31), "saved": False, "modules": {}} # experiment directory directory = self.cfg.get("experiment", {}).get("directory", "") @@ -82,7 +86,9 @@ def __init__(self, if not directory: directory = os.path.join(os.getcwd(), "runs") if not experiment_name: - experiment_name = 
"{}_{}".format(datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S-%f"), self.__class__.__name__) + experiment_name = "{}_{}".format( + datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S-%f"), self.__class__.__name__ + ) self.experiment_dir = os.path.join(directory, experiment_name) def __str__(self) -> str: @@ -149,7 +155,7 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: models_cfg = {k: v.net._modules for (k, v) in self.models.items()} except AttributeError: models_cfg = {k: v._modules for (k, v) in self.models.items()} - wandb_config={**self.cfg, **trainer_cfg, **models_cfg} + wandb_config = {**self.cfg, **trainer_cfg, **models_cfg} # set default values wandb_kwargs = copy.deepcopy(self.cfg.get("experiment", {}).get("wandb_kwargs", {})) wandb_kwargs.setdefault("name", os.path.split(self.experiment_dir)[-1]) @@ -158,6 +164,7 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: wandb_kwargs["config"].update(wandb_config) # init Weights & Biases import wandb + wandb.init(**wandb_kwargs) # main entry to log data for consumption and visualization by TensorBoard @@ -218,8 +225,10 @@ def write_checkpoint(self, timestep: int, timesteps: int) -> None: # separated modules if self.checkpoint_store_separately: for name, module in self.checkpoint_modules.items(): - torch.save(self._get_internal_value(module), - os.path.join(self.experiment_dir, "checkpoints", f"{name}_{tag}.pt")) + torch.save( + self._get_internal_value(module), + os.path.join(self.experiment_dir, "checkpoints", f"{name}_{tag}.pt"), + ) # whole agent else: modules = {} @@ -232,8 +241,10 @@ def write_checkpoint(self, timestep: int, timesteps: int) -> None: # separated modules if self.checkpoint_store_separately: for name, module in self.checkpoint_modules.items(): - torch.save(self.checkpoint_best_modules["modules"][name], - os.path.join(self.experiment_dir, "checkpoints", f"best_{name}.pt")) + torch.save( + self.checkpoint_best_modules["modules"][name], + os.path.join(self.experiment_dir, "checkpoints", f"best_{name}.pt"), + ) # whole agent else: modules = {} @@ -242,10 +253,7 @@ def write_checkpoint(self, timestep: int, timesteps: int) -> None: torch.save(modules, os.path.join(self.experiment_dir, "checkpoints", "best_agent.pt")) self.checkpoint_best_modules["saved"] = True - def act(self, - states: torch.Tensor, - timestep: int, - timesteps: int) -> torch.Tensor: + def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tensor: """Process the environment's states to make a decision (actions) using the main policy :param states: Environment's states @@ -262,16 +270,18 @@ def act(self, """ raise NotImplementedError - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory (to be implemented by the inheriting classes) Inheriting classes must call this method to record episode information (rewards, timesteps, etc.). @@ -391,11 +401,13 @@ def load(self, path: str) -> None: else: logger.warning(f"Cannot load the {name} module. 
The agent doesn't have such an instance") - def migrate(self, - path: str, - name_map: Mapping[str, Mapping[str, str]] = {}, - auto_mapping: bool = True, - verbose: bool = False) -> bool: + def migrate( + self, + path: str, + name_map: Mapping[str, Mapping[str, str]] = {}, + auto_mapping: bool = True, + verbose: bool = False, + ) -> bool: """Migrate the specified external checkpoint to the current agent The final storage device is determined by the constructor of the agent. @@ -621,10 +633,12 @@ def migrate(self, if module not in ["state_preprocessor", "value_preprocessor", "optimizer"] and hasattr(module, "migrate"): if verbose: logger.info(f"Model: {name} ({type(module).__name__})") - status *= module.migrate(state_dict=checkpoint["model"], - name_map=name_map.get(name, {}), - auto_mapping=auto_mapping, - verbose=verbose) + status *= module.migrate( + state_dict=checkpoint["model"], + name_map=name_map.get(name, {}), + auto_mapping=auto_mapping, + verbose=verbose, + ) self.set_mode("eval") return bool(status) @@ -652,12 +666,14 @@ def post_interaction(self, timestep: int, timesteps: int) -> None: # update best models and write checkpoints if timestep > 1 and self.checkpoint_interval > 0 and not timestep % self.checkpoint_interval: # update best models - reward = np.mean(self.tracking_data.get("Reward / Total reward (mean)", -2 ** 31)) + reward = np.mean(self.tracking_data.get("Reward / Total reward (mean)", -(2**31))) if reward > self.checkpoint_best_modules["reward"]: self.checkpoint_best_modules["timestep"] = timestep self.checkpoint_best_modules["reward"] = reward self.checkpoint_best_modules["saved"] = False - self.checkpoint_best_modules["modules"] = {k: copy.deepcopy(self._get_internal_value(v)) for k, v in self.checkpoint_modules.items()} + self.checkpoint_best_modules["modules"] = { + k: copy.deepcopy(self._get_internal_value(v)) for k, v in self.checkpoint_modules.items() + } # write checkpoints self.write_checkpoint(timestep, timesteps) diff --git a/skrl/agents/torch/cem/cem.py b/skrl/agents/torch/cem/cem.py index 735250e4..4daf2ee7 100644 --- a/skrl/agents/torch/cem/cem.py +++ b/skrl/agents/torch/cem/cem.py @@ -49,13 +49,15 @@ class CEM(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Cross-Entropy Method (CEM) https://ieeexplore.ieee.org/abstract/document/6796865/ @@ -80,12 +82,14 @@ def __init__(self, """ _cfg = copy.deepcopy(CEM_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -116,7 +120,9 @@ def __init__(self, if self.policy is not None: 
self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._learning_rate) if self._learning_rate_scheduler is not None: - self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = self._learning_rate_scheduler( + self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["optimizer"] = self.optimizer @@ -128,8 +134,7 @@ def __init__(self, self._state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) # create tensors in memory @@ -165,16 +170,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens # sample stochastic actions return self.policy.act({"states": states}, role="policy") - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -196,18 +203,32 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) # reward shaping if self._rewards_shaper is not None: rewards = self._rewards_shaper(rewards, timestep, timesteps) if self.memory is not None: - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) # track episodes internally if self._rollout: @@ -264,9 +285,14 @@ def _update(self, timestep: int, timesteps: int) -> None: for e in range(sampled_rewards.size(-1)): for i, j in zip(self._episode_tracking[e][:-1], self._episode_tracking[e][1:]): limits.append([e + i, e + j]) - rewards = sampled_rewards[e + i: e + j] - returns.append(torch.sum(rewards * self._discount_factor ** \ - torch.arange(rewards.size(0), device=rewards.device).flip(-1).view(rewards.size()))) + rewards = sampled_rewards[e + i : e + j] + returns.append( + torch.sum( + rewards + * self._discount_factor + ** torch.arange(rewards.size(0), device=rewards.device).flip(-1).view(rewards.size()) + ) + ) if not len(returns): logger.warning("No returns to update. 
Consider increasing the number of rollouts") @@ -277,8 +303,8 @@ def _update(self, timestep: int, timesteps: int) -> None: # get elite states and actions indexes = torch.nonzero(returns >= return_threshold) - elite_states = torch.cat([sampled_states[limits[i][0]:limits[i][1]] for i in indexes[:, 0]], dim=0) - elite_actions = torch.cat([sampled_actions[limits[i][0]:limits[i][1]] for i in indexes[:, 0]], dim=0) + elite_states = torch.cat([sampled_states[limits[i][0] : limits[i][1]] for i in indexes[:, 0]], dim=0) + elite_actions = torch.cat([sampled_actions[limits[i][0] : limits[i][1]] for i in indexes[:, 0]], dim=0) # compute scores for the elite states _, _, outputs = self.policy.act({"states": elite_states}, role="policy") diff --git a/skrl/agents/torch/ddpg/ddpg.py b/skrl/agents/torch/ddpg/ddpg.py index 16f094d8..25983eba 100644 --- a/skrl/agents/torch/ddpg/ddpg.py +++ b/skrl/agents/torch/ddpg/ddpg.py @@ -61,13 +61,15 @@ class DDPG(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Deep Deterministic Policy Gradient (DDPG) https://arxiv.org/abs/1509.02971 @@ -92,12 +94,14 @@ def __init__(self, """ _cfg = copy.deepcopy(DDPG_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -120,7 +124,7 @@ def __init__(self, self.critic.broadcast_parameters() if self.target_policy is not None and self.target_critic is not None: - # freeze target networks with respect to optimizers (update via .update_parameters()) + # freeze target networks with respect to optimizers (update via .update_parameters()) self.target_policy.freeze_parameters(True) self.target_critic.freeze_parameters(True) @@ -158,8 +162,12 @@ def __init__(self, self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._actor_learning_rate) self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self._critic_learning_rate) if self._learning_rate_scheduler is not None: - self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) - self.critic_scheduler = self._learning_rate_scheduler(self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.policy_scheduler = self._learning_rate_scheduler( + self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) + self.critic_scheduler = self._learning_rate_scheduler( + self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer self.checkpoint_modules["critic_optimizer"] = 
self.critic_optimizer @@ -172,8 +180,7 @@ def __init__(self, self._state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -224,9 +231,9 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens # apply exploration noise if timestep <= self._exploration_timesteps: - scale = (1 - timestep / self._exploration_timesteps) \ - * (self._exploration_initial_scale - self._exploration_final_scale) \ - + self._exploration_final_scale + scale = (1 - timestep / self._exploration_timesteps) * ( + self._exploration_initial_scale - self._exploration_final_scale + ) + self._exploration_final_scale noises.mul_(scale) # modify actions @@ -246,16 +253,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, None, outputs - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -277,7 +286,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: # reward shaping @@ -285,11 +296,23 @@ def record_transition(self, rewards = self._rewards_shaper(rewards, timestep, timesteps) # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -330,8 +353,9 @@ def _update(self, timestep: int, timesteps: int) -> None: for gradient_step in range(self._gradient_steps): # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = self.memory.sample( + names=self._tensors_names, batch_size=self._batch_size + )[0] sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, 
train=True) @@ -340,11 +364,15 @@ def _update(self, timestep: int, timesteps: int) -> None: with torch.no_grad(): next_actions, _, _ = self.target_policy.act({"states": sampled_next_states}, role="target_policy") - target_q_values, _, _ = self.target_critic.act({"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic") + target_q_values, _, _ = self.target_critic.act( + {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic" + ) target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values # compute critic loss - critic_values, _, _ = self.critic.act({"states": sampled_states, "taken_actions": sampled_actions}, role="critic") + critic_values, _, _ = self.critic.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="critic" + ) critic_loss = F.mse_loss(critic_values, target_values) diff --git a/skrl/agents/torch/ddpg/ddpg_rnn.py b/skrl/agents/torch/ddpg/ddpg_rnn.py index 6d0dd829..89ddebca 100644 --- a/skrl/agents/torch/ddpg/ddpg_rnn.py +++ b/skrl/agents/torch/ddpg/ddpg_rnn.py @@ -61,13 +61,15 @@ class DDPG_RNN(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Deep Deterministic Policy Gradient (DDPG) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.) 
https://arxiv.org/abs/1509.02971 @@ -92,12 +94,14 @@ def __init__(self, """ _cfg = copy.deepcopy(DDPG_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -120,7 +124,7 @@ def __init__(self, self.critic.broadcast_parameters() if self.target_policy is not None and self.target_critic is not None: - # freeze target networks with respect to optimizers (update via .update_parameters()) + # freeze target networks with respect to optimizers (update via .update_parameters()) self.target_policy.freeze_parameters(True) self.target_critic.freeze_parameters(True) @@ -158,8 +162,12 @@ def __init__(self, self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._actor_learning_rate) self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self._critic_learning_rate) if self._learning_rate_scheduler is not None: - self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) - self.critic_scheduler = self._learning_rate_scheduler(self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.policy_scheduler = self._learning_rate_scheduler( + self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) + self.critic_scheduler = self._learning_rate_scheduler( + self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer self.checkpoint_modules["critic_optimizer"] = self.critic_optimizer @@ -172,8 +180,7 @@ def __init__(self, self._state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -199,7 +206,9 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self._rnn = True # create tensors in memory if self.memory is not None: - self.memory.create_tensor(name=f"rnn_policy_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True) + self.memory.create_tensor( + name=f"rnn_policy_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True + ) self._rnn_tensors_names.append(f"rnn_policy_{i}") # default RNN states self._rnn_initial_states["policy"].append(torch.zeros(size, dtype=torch.float32, device=self.device)) @@ -246,9 +255,9 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens # apply exploration noise if timestep <= self._exploration_timesteps: - scale = (1 - timestep / self._exploration_timesteps) \ - * (self._exploration_initial_scale - self._exploration_final_scale) \ - + self._exploration_final_scale + scale = (1 - timestep / self._exploration_timesteps) * ( + self._exploration_initial_scale - self._exploration_final_scale + ) + self._exploration_final_scale noises.mul_(scale) # modify actions @@ -268,16 +277,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, None, outputs - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - 
truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -299,7 +310,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: # reward shaping @@ -309,14 +322,30 @@ def record_transition(self, # package RNN states rnn_states = {} if self._rnn: - rnn_states.update({f"rnn_policy_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["policy"])}) + rnn_states.update( + {f"rnn_policy_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["policy"])} + ) # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, **rnn_states) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + **rnn_states, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, **rnn_states) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + **rnn_states, + ) # update RNN states if self._rnn: @@ -367,12 +396,15 @@ def _update(self, timestep: int, timesteps: int) -> None: for gradient_step in range(self._gradient_steps): # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0] + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = self.memory.sample( + names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length + )[0] rnn_policy = {} if self._rnn: - sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0] + sampled_rnn = self.memory.sample_by_index( + names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes() + )[0] rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones} sampled_states = self._state_preprocessor(sampled_states, train=True) @@ -380,13 +412,19 @@ def _update(self, timestep: int, timesteps: int) -> None: # compute target values with torch.no_grad(): - next_actions, _, _ = self.target_policy.act({"states": sampled_next_states, **rnn_policy}, role="target_policy") + next_actions, _, _ = self.target_policy.act( + {"states": sampled_next_states, **rnn_policy}, role="target_policy" + ) - target_q_values, _, _ = self.target_critic.act({"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, role="target_critic") + target_q_values, _, _ = 
self.target_critic.act( + {"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, role="target_critic" + ) target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values # compute critic loss - critic_values, _, _ = self.critic.act({"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic") + critic_values, _, _ = self.critic.act( + {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic" + ) critic_loss = F.mse_loss(critic_values, target_values) @@ -401,7 +439,9 @@ def _update(self, timestep: int, timesteps: int) -> None: # compute policy (actor) loss actions, _, _ = self.policy.act({"states": sampled_states, **rnn_policy}, role="policy") - critic_values, _, _ = self.critic.act({"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic") + critic_values, _, _ = self.critic.act( + {"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic" + ) policy_loss = -critic_values.mean() diff --git a/skrl/agents/torch/dqn/ddqn.py b/skrl/agents/torch/dqn/ddqn.py index c2e84d02..fdb73e84 100644 --- a/skrl/agents/torch/dqn/ddqn.py +++ b/skrl/agents/torch/dqn/ddqn.py @@ -60,13 +60,15 @@ class DDQN(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Double Deep Q-Network (DDQN) https://ojs.aaai.org/index.php/AAAI/article/view/10295 @@ -91,12 +93,14 @@ def __init__(self, """ _cfg = copy.deepcopy(DDQN_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.q_network = self.models.get("q_network", None) @@ -147,7 +151,9 @@ def __init__(self, if self.q_network is not None: self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=self._learning_rate) if self._learning_rate_scheduler is not None: - self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = self._learning_rate_scheduler( + self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["optimizer"] = self.optimizer @@ -159,8 +165,7 @@ def __init__(self, self._state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) # create tensors in memory @@ -189,7 +194,11 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens states = self._state_preprocessor(states) if not self._exploration_timesteps: - return 
torch.argmax(self.q_network.act({"states": states}, role="q_network")[0], dim=1, keepdim=True), None, None + return ( + torch.argmax(self.q_network.act({"states": states}, role="q_network")[0], dim=1, keepdim=True), + None, + None, + ) # sample random actions actions = self.q_network.random_act({"states": states}, role="q_network")[0] @@ -197,28 +206,33 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, None, None # sample actions with epsilon-greedy policy - epsilon = self._exploration_final_epsilon + (self._exploration_initial_epsilon - self._exploration_final_epsilon) \ - * math.exp(-1.0 * timestep / self._exploration_timesteps) + epsilon = self._exploration_final_epsilon + ( + self._exploration_initial_epsilon - self._exploration_final_epsilon + ) * math.exp(-1.0 * timestep / self._exploration_timesteps) indexes = (torch.rand(states.shape[0], device=self.device) >= epsilon).nonzero().view(-1) if indexes.numel(): - actions[indexes] = torch.argmax(self.q_network.act({"states": states[indexes]}, role="q_network")[0], dim=1, keepdim=True) + actions[indexes] = torch.argmax( + self.q_network.act({"states": states[indexes]}, role="q_network")[0], dim=1, keepdim=True + ) # record epsilon self.track_data("Exploration / Exploration epsilon", epsilon) return actions, None, None - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -240,18 +254,32 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: # reward shaping if self._rewards_shaper is not None: rewards = self._rewards_shaper(rewards, timestep, timesteps) - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -290,22 +318,32 @@ def _update(self, timestep: int, timesteps: int) -> None: for gradient_step in range(self._gradient_steps): # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self.tensors_names, 
batch_size=self._batch_size)[0] + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = self.memory.sample( + names=self.tensors_names, batch_size=self._batch_size + )[0] sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) # compute target values with torch.no_grad(): - next_q_values, _, _ = self.target_q_network.act({"states": sampled_next_states}, role="target_q_network") - - target_q_values = torch.gather(next_q_values, dim=1, index=torch.argmax(self.q_network.act({"states": sampled_next_states}, \ - role="q_network")[0], dim=1, keepdim=True)) + next_q_values, _, _ = self.target_q_network.act( + {"states": sampled_next_states}, role="target_q_network" + ) + + target_q_values = torch.gather( + next_q_values, + dim=1, + index=torch.argmax( + self.q_network.act({"states": sampled_next_states}, role="q_network")[0], dim=1, keepdim=True + ), + ) target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values # compute Q-network loss - q_values = torch.gather(self.q_network.act({"states": sampled_states}, role="q_network")[0], dim=1, index=sampled_actions.long()) + q_values = torch.gather( + self.q_network.act({"states": sampled_states}, role="q_network")[0], dim=1, index=sampled_actions.long() + ) q_network_loss = F.mse_loss(q_values, target_values) diff --git a/skrl/agents/torch/dqn/dqn.py b/skrl/agents/torch/dqn/dqn.py index 5fd5ba48..318b56a3 100644 --- a/skrl/agents/torch/dqn/dqn.py +++ b/skrl/agents/torch/dqn/dqn.py @@ -60,13 +60,15 @@ class DQN(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Deep Q-Network (DQN) https://arxiv.org/abs/1312.5602 @@ -91,12 +93,14 @@ def __init__(self, """ _cfg = copy.deepcopy(DQN_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.q_network = self.models.get("q_network", None) @@ -147,7 +151,9 @@ def __init__(self, if self.q_network is not None: self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=self._learning_rate) if self._learning_rate_scheduler is not None: - self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = self._learning_rate_scheduler( + self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["optimizer"] = self.optimizer @@ -159,8 +165,7 @@ def __init__(self, self._state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) 
-> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) # create tensors in memory @@ -189,7 +194,11 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens states = self._state_preprocessor(states) if not self._exploration_timesteps: - return torch.argmax(self.q_network.act({"states": states}, role="q_network")[0], dim=1, keepdim=True), None, None + return ( + torch.argmax(self.q_network.act({"states": states}, role="q_network")[0], dim=1, keepdim=True), + None, + None, + ) # sample random actions actions = self.q_network.random_act({"states": states}, role="q_network")[0] @@ -197,28 +206,33 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, None, None # sample actions with epsilon-greedy policy - epsilon = self._exploration_final_epsilon + (self._exploration_initial_epsilon - self._exploration_final_epsilon) \ - * math.exp(-1.0 * timestep / self._exploration_timesteps) + epsilon = self._exploration_final_epsilon + ( + self._exploration_initial_epsilon - self._exploration_final_epsilon + ) * math.exp(-1.0 * timestep / self._exploration_timesteps) indexes = (torch.rand(states.shape[0], device=self.device) >= epsilon).nonzero().view(-1) if indexes.numel(): - actions[indexes] = torch.argmax(self.q_network.act({"states": states[indexes]}, role="q_network")[0], dim=1, keepdim=True) + actions[indexes] = torch.argmax( + self.q_network.act({"states": states[indexes]}, role="q_network")[0], dim=1, keepdim=True + ) # record epsilon self.track_data("Exploration / Exploration epsilon", epsilon) return actions, None, None - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -240,18 +254,32 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: # reward shaping if self._rewards_shaper is not None: rewards = self._rewards_shaper(rewards, timestep, timesteps) - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the 
environment @@ -290,22 +318,26 @@ def _update(self, timestep: int, timesteps: int) -> None: for gradient_step in range(self._gradient_steps): # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0] + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = self.memory.sample( + names=self.tensors_names, batch_size=self._batch_size + )[0] sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) # compute target values with torch.no_grad(): - next_q_values, _, _ = self.target_q_network.act({"states": sampled_next_states}, role="target_q_network") + next_q_values, _, _ = self.target_q_network.act( + {"states": sampled_next_states}, role="target_q_network" + ) target_q_values = torch.max(next_q_values, dim=-1, keepdim=True)[0] target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values # compute Q-network loss - q_values = torch.gather(self.q_network.act({"states": sampled_states}, role="q_network")[0], - dim=1, index=sampled_actions.long()) + q_values = torch.gather( + self.q_network.act({"states": sampled_states}, role="q_network")[0], dim=1, index=sampled_actions.long() + ) q_network_loss = F.mse_loss(q_values, target_values) diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py index 26febbf2..e6be90c4 100644 --- a/skrl/agents/torch/ppo/ppo.py +++ b/skrl/agents/torch/ppo/ppo.py @@ -69,13 +69,15 @@ class PPO(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Proximal Policy Optimization (PPO) https://arxiv.org/abs/1707.06347 @@ -100,12 +102,14 @@ def __init__(self, """ _cfg = copy.deepcopy(PPO_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -165,10 +169,13 @@ def __init__(self, if self.policy is self.value: self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._learning_rate) else: - self.optimizer = torch.optim.Adam(itertools.chain(self.policy.parameters(), self.value.parameters()), - lr=self._learning_rate) + self.optimizer = torch.optim.Adam( + itertools.chain(self.policy.parameters(), self.value.parameters()), lr=self._learning_rate + ) if self._learning_rate_scheduler is not None: - self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = 
self._learning_rate_scheduler( + self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["optimizer"] = self.optimizer @@ -186,8 +193,7 @@ def __init__(self, self._value_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -234,16 +240,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, log_prob, outputs - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -265,7 +273,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: self._current_next_states = next_states @@ -284,11 +294,27 @@ def record_transition(self, rewards += self._discount_factor * values * truncated # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -325,12 +351,15 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - def compute_gae(rewards: torch.Tensor, - dones: torch.Tensor, - values: torch.Tensor, - next_values: torch.Tensor, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> torch.Tensor: + + def compute_gae( + rewards: torch.Tensor, + dones: torch.Tensor, + values: torch.Tensor, + next_values: torch.Tensor, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, + ) -> torch.Tensor: """Compute the Generalized Advantage Estimator (GAE) :param rewards: Rewards obtained by the agent @@ -357,7 +386,11 @@ def compute_gae(rewards: torch.Tensor, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size 
- 1 else last_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] + - values[i] + + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages[i] = advantage # returns computation returns = advantages + values @@ -369,17 +402,21 @@ def compute_gae(rewards: torch.Tensor, # compute returns and advantages with torch.no_grad(): self.value.train(False) - last_values, _, _ = self.value.act({"states": self._state_preprocessor(self._current_next_states.float())}, role="value") + last_values, _, _ = self.value.act( + {"states": self._state_preprocessor(self._current_next_states.float())}, role="value" + ) self.value.train(True) last_values = self._value_preprocessor(last_values, inverse=True) values = self.memory.get_tensor_by_name("values") - returns, advantages = compute_gae(rewards=self.memory.get_tensor_by_name("rewards"), - dones=self.memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor, - lambda_coefficient=self._lambda) + returns, advantages = compute_gae( + rewards=self.memory.get_tensor_by_name("rewards"), + dones=self.memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor, + lambda_coefficient=self._lambda, + ) self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True)) self.memory.set_tensor_by_name("returns", self._value_preprocessor(returns, train=True)) @@ -397,13 +434,22 @@ def compute_gae(rewards: torch.Tensor, kl_divergences = [] # mini-batches loop - for sampled_states, sampled_actions, sampled_log_prob, sampled_values, sampled_returns, sampled_advantages in sampled_batches: + for ( + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_values, + sampled_returns, + sampled_advantages, + ) in sampled_batches: with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): sampled_states = self._state_preprocessor(sampled_states, train=not epoch) - _, next_log_prob, _ = self.policy.act({"states": sampled_states, "taken_actions": sampled_actions}, role="policy") + _, next_log_prob, _ = self.policy.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="policy" + ) # compute approximate KL divergence with torch.no_grad(): @@ -424,7 +470,9 @@ def compute_gae(rewards: torch.Tensor, # compute policy loss ratio = torch.exp(next_log_prob - sampled_log_prob) surrogate = sampled_advantages * ratio - surrogate_clipped = sampled_advantages * torch.clip(ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip) + surrogate_clipped = sampled_advantages * torch.clip( + ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip + ) policy_loss = -torch.min(surrogate, surrogate_clipped).mean() @@ -432,9 +480,9 @@ def compute_gae(rewards: torch.Tensor, predicted_values, _, _ = self.value.act({"states": sampled_states}, role="value") if self._clip_predicted_values: - predicted_values = sampled_values + torch.clip(predicted_values - sampled_values, - min=-self._value_clip, - max=self._value_clip) + predicted_values = sampled_values + torch.clip( + predicted_values - sampled_values, min=-self._value_clip, max=self._value_clip + ) value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values) # optimization step @@ -451,7 +499,9 @@ def compute_gae(rewards: torch.Tensor, if self.policy is self.value: 
                         nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
                     else:
-                        nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip)
+                        nn.utils.clip_grad_norm_(
+                            itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip
+                        )
 
                 self._scaler.step(self.optimizer)
                 self._scaler.update()
@@ -478,7 +528,9 @@ def compute_gae(rewards: torch.Tensor,
         self.track_data("Loss / Policy loss", cumulative_policy_loss / (self._learning_epochs * self._mini_batches))
         self.track_data("Loss / Value loss", cumulative_value_loss / (self._learning_epochs * self._mini_batches))
         if self._entropy_loss_scale:
-            self.track_data("Loss / Entropy loss", cumulative_entropy_loss / (self._learning_epochs * self._mini_batches))
+            self.track_data(
+                "Loss / Entropy loss", cumulative_entropy_loss / (self._learning_epochs * self._mini_batches)
+            )
 
         self.track_data("Policy / Standard deviation", self.policy.distribution(role="policy").stddev.mean().item())
 
diff --git a/skrl/agents/torch/ppo/ppo_rnn.py b/skrl/agents/torch/ppo/ppo_rnn.py
index f19ca3c2..7259507c 100644
--- a/skrl/agents/torch/ppo/ppo_rnn.py
+++ b/skrl/agents/torch/ppo/ppo_rnn.py
@@ -67,13 +67,15 @@ class PPO_RNN(Agent):
-    def __init__(self,
-                 models: Mapping[str, Model],
-                 memory: Optional[Union[Memory, Tuple[Memory]]] = None,
-                 observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
-                 action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
-                 device: Optional[Union[str, torch.device]] = None,
-                 cfg: Optional[dict] = None) -> None:
+    def __init__(
+        self,
+        models: Mapping[str, Model],
+        memory: Optional[Union[Memory, Tuple[Memory]]] = None,
+        observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
+        action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
+        device: Optional[Union[str, torch.device]] = None,
+        cfg: Optional[dict] = None,
+    ) -> None:
         """Proximal Policy Optimization (PPO) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.)
https://arxiv.org/abs/1707.06347 @@ -98,12 +100,14 @@ def __init__(self, """ _cfg = copy.deepcopy(PPO_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -157,10 +161,13 @@ def __init__(self, if self.policy is self.value: self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._learning_rate) else: - self.optimizer = torch.optim.Adam(itertools.chain(self.policy.parameters(), self.value.parameters()), - lr=self._learning_rate) + self.optimizer = torch.optim.Adam( + itertools.chain(self.policy.parameters(), self.value.parameters()), lr=self._learning_rate + ) if self._learning_rate_scheduler is not None: - self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = self._learning_rate_scheduler( + self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["optimizer"] = self.optimizer @@ -178,8 +185,7 @@ def __init__(self, self._value_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -209,7 +215,9 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self._rnn = True # create tensors in memory if self.memory is not None: - self.memory.create_tensor(name=f"rnn_policy_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True) + self.memory.create_tensor( + name=f"rnn_policy_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True + ) self._rnn_tensors_names.append(f"rnn_policy_{i}") # default RNN states self._rnn_initial_states["policy"].append(torch.zeros(size, dtype=torch.float32, device=self.device)) @@ -223,7 +231,9 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self._rnn = True # create tensors in memory if self.memory is not None: - self.memory.create_tensor(name=f"rnn_value_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True) + self.memory.create_tensor( + name=f"rnn_value_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True + ) self._rnn_tensors_names.append(f"rnn_value_{i}") # default RNN states self._rnn_initial_states["value"].append(torch.zeros(size, dtype=torch.float32, device=self.device)) @@ -261,16 +271,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, log_prob, outputs - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -292,7 +304,9 @@ def record_transition(self, :param timesteps: 
Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: self._current_next_states = next_states @@ -313,20 +327,44 @@ def record_transition(self, # package RNN states rnn_states = {} if self._rnn: - rnn_states.update({f"rnn_policy_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["policy"])}) + rnn_states.update( + {f"rnn_policy_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["policy"])} + ) if self.policy is not self.value: - rnn_states.update({f"rnn_value_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["value"])}) + rnn_states.update( + {f"rnn_value_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["value"])} + ) # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values, **rnn_states) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + **rnn_states, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values, **rnn_states) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + **rnn_states, + ) # update RNN states if self._rnn: - self._rnn_final_states["value"] = self._rnn_final_states["policy"] if self.policy is self.value else outputs.get("rnn", []) + self._rnn_final_states["value"] = ( + self._rnn_final_states["policy"] if self.policy is self.value else outputs.get("rnn", []) + ) # reset states if the episodes have ended finished_episodes = terminated.nonzero(as_tuple=False) @@ -374,12 +412,15 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - def compute_gae(rewards: torch.Tensor, - dones: torch.Tensor, - values: torch.Tensor, - next_values: torch.Tensor, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> torch.Tensor: + + def compute_gae( + rewards: torch.Tensor, + dones: torch.Tensor, + values: torch.Tensor, + next_values: torch.Tensor, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, + ) -> torch.Tensor: """Compute the Generalized Advantage Estimator (GAE) :param rewards: Rewards obtained by the agent @@ -406,7 +447,11 @@ def compute_gae(rewards: torch.Tensor, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else last_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] + - values[i] + + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages[i] = advantage # returns computation returns = advantages + values @@ -419,28 +464,38 @@ def compute_gae(rewards: torch.Tensor, with torch.no_grad(): 
self.value.train(False) rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {} - last_values, _, _ = self.value.act({"states": self._state_preprocessor(self._current_next_states.float()), **rnn}, role="value") + last_values, _, _ = self.value.act( + {"states": self._state_preprocessor(self._current_next_states.float()), **rnn}, role="value" + ) self.value.train(True) last_values = self._value_preprocessor(last_values, inverse=True) values = self.memory.get_tensor_by_name("values") - returns, advantages = compute_gae(rewards=self.memory.get_tensor_by_name("rewards"), - dones=self.memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor, - lambda_coefficient=self._lambda) + returns, advantages = compute_gae( + rewards=self.memory.get_tensor_by_name("rewards"), + dones=self.memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor, + lambda_coefficient=self._lambda, + ) self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True)) self.memory.set_tensor_by_name("returns", self._value_preprocessor(returns, train=True)) self.memory.set_tensor_by_name("advantages", advantages) # sample mini-batches from memory - sampled_batches = self.memory.sample_all(names=self._tensors_names, mini_batches=self._mini_batches, sequence_length=self._rnn_sequence_length) + sampled_batches = self.memory.sample_all( + names=self._tensors_names, mini_batches=self._mini_batches, sequence_length=self._rnn_sequence_length + ) rnn_policy, rnn_value = {}, {} if self._rnn: - sampled_rnn_batches = self.memory.sample_all(names=self._rnn_tensors_names, mini_batches=self._mini_batches, sequence_length=self._rnn_sequence_length) + sampled_rnn_batches = self.memory.sample_all( + names=self._rnn_tensors_names, + mini_batches=self._mini_batches, + sequence_length=self._rnn_sequence_length, + ) cumulative_policy_loss = 0 cumulative_entropy_loss = 0 @@ -451,19 +506,46 @@ def compute_gae(rewards: torch.Tensor, kl_divergences = [] # mini-batches loop - for i, (sampled_states, sampled_actions, sampled_dones, sampled_log_prob, sampled_values, sampled_returns, sampled_advantages) in enumerate(sampled_batches): + for i, ( + sampled_states, + sampled_actions, + sampled_dones, + sampled_log_prob, + sampled_values, + sampled_returns, + sampled_advantages, + ) in enumerate(sampled_batches): if self._rnn: if self.policy is self.value: - rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn_batches[i]], "terminated": sampled_dones} + rnn_policy = { + "rnn": [s.transpose(0, 1) for s in sampled_rnn_batches[i]], + "terminated": sampled_dones, + } rnn_value = rnn_policy else: - rnn_policy = {"rnn": [s.transpose(0, 1) for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) if "policy" in n], "terminated": sampled_dones} - rnn_value = {"rnn": [s.transpose(0, 1) for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) if "value" in n], "terminated": sampled_dones} + rnn_policy = { + "rnn": [ + s.transpose(0, 1) + for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) + if "policy" in n + ], + "terminated": sampled_dones, + } + rnn_value = { + "rnn": [ + s.transpose(0, 1) + for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) + if "value" in n + ], + "terminated": sampled_dones, + } sampled_states = self._state_preprocessor(sampled_states, train=not epoch) - _, next_log_prob, _ = self.policy.act({"states": sampled_states, 
"taken_actions": sampled_actions, **rnn_policy}, role="policy") + _, next_log_prob, _ = self.policy.act( + {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="policy" + ) # compute approximate KL divergence with torch.no_grad(): @@ -484,7 +566,9 @@ def compute_gae(rewards: torch.Tensor, # compute policy loss ratio = torch.exp(next_log_prob - sampled_log_prob) surrogate = sampled_advantages * ratio - surrogate_clipped = sampled_advantages * torch.clip(ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip) + surrogate_clipped = sampled_advantages * torch.clip( + ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip + ) policy_loss = -torch.min(surrogate, surrogate_clipped).mean() @@ -492,9 +576,9 @@ def compute_gae(rewards: torch.Tensor, predicted_values, _, _ = self.value.act({"states": sampled_states, **rnn_value}, role="value") if self._clip_predicted_values: - predicted_values = sampled_values + torch.clip(predicted_values - sampled_values, - min=-self._value_clip, - max=self._value_clip) + predicted_values = sampled_values + torch.clip( + predicted_values - sampled_values, min=-self._value_clip, max=self._value_clip + ) value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values) # optimization step @@ -508,7 +592,9 @@ def compute_gae(rewards: torch.Tensor, if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) else: - nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip) + nn.utils.clip_grad_norm_( + itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip + ) self.optimizer.step() # update cumulative losses @@ -533,7 +619,9 @@ def compute_gae(rewards: torch.Tensor, self.track_data("Loss / Policy loss", cumulative_policy_loss / (self._learning_epochs * self._mini_batches)) self.track_data("Loss / Value loss", cumulative_value_loss / (self._learning_epochs * self._mini_batches)) if self._entropy_loss_scale: - self.track_data("Loss / Entropy loss", cumulative_entropy_loss / (self._learning_epochs * self._mini_batches)) + self.track_data( + "Loss / Entropy loss", cumulative_entropy_loss / (self._learning_epochs * self._mini_batches) + ) self.track_data("Policy / Standard deviation", self.policy.distribution(role="policy").stddev.mean().item()) diff --git a/skrl/agents/torch/q_learning/q_learning.py b/skrl/agents/torch/q_learning/q_learning.py index ad247627..76124afc 100644 --- a/skrl/agents/torch/q_learning/q_learning.py +++ b/skrl/agents/torch/q_learning/q_learning.py @@ -39,13 +39,15 @@ class Q_LEARNING(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Q-learning https://www.academia.edu/3294050/Learning_from_delayed_rewards @@ -70,12 +72,14 @@ def __init__(self, """ _cfg = copy.deepcopy(Q_LEARNING_DEFAULT_CONFIG) _cfg.update(cfg if 
cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -101,8 +105,7 @@ def __init__(self, self._current_dones = None def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tensor: @@ -125,16 +128,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens # sample actions from policy return self.policy.act({"states": states}, role="policy") - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -156,7 +161,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) # reward shaping if self._rewards_shaper is not None: @@ -169,11 +176,23 @@ def record_transition(self, self._current_dones = terminated + truncated if self.memory is not None: - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -211,10 +230,13 @@ def _update(self, timestep: int, timesteps: int) -> None: env_ids = torch.arange(self._current_rewards.shape[0]).view(-1, 1) # compute next actions - next_actions = torch.argmax(q_table[env_ids, self._current_next_states], dim=-1, keepdim=True).view(-1,1) + next_actions = torch.argmax(q_table[env_ids, self._current_next_states], dim=-1, keepdim=True).view(-1, 1) # update Q-table - q_table[env_ids, self._current_states, self._current_actions] += self._learning_rate \ - * (self._current_rewards + self._discount_factor * self._current_dones.logical_not() \ - * q_table[env_ids, self._current_next_states, next_actions] \ - - q_table[env_ids, self._current_states, self._current_actions]) + q_table[env_ids, 
self._current_states, self._current_actions] += self._learning_rate * ( + self._current_rewards + + self._discount_factor + * self._current_dones.logical_not() + * q_table[env_ids, self._current_next_states, next_actions] + - q_table[env_ids, self._current_states, self._current_actions] + ) diff --git a/skrl/agents/torch/rpo/rpo.py b/skrl/agents/torch/rpo/rpo.py index 25eb462b..24561945 100644 --- a/skrl/agents/torch/rpo/rpo.py +++ b/skrl/agents/torch/rpo/rpo.py @@ -68,13 +68,15 @@ class RPO(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Robust Policy Optimization (RPO) https://arxiv.org/abs/2212.07536 @@ -99,12 +101,14 @@ def __init__(self, """ _cfg = copy.deepcopy(RPO_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -159,10 +163,13 @@ def __init__(self, if self.policy is self.value: self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._learning_rate) else: - self.optimizer = torch.optim.Adam(itertools.chain(self.policy.parameters(), self.value.parameters()), - lr=self._learning_rate) + self.optimizer = torch.optim.Adam( + itertools.chain(self.policy.parameters(), self.value.parameters()), lr=self._learning_rate + ) if self._learning_rate_scheduler is not None: - self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = self._learning_rate_scheduler( + self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["optimizer"] = self.optimizer @@ -180,8 +187,7 @@ def __init__(self, self._value_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -222,21 +228,25 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy") # sample stochastic actions - actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states), "alpha": self._alpha}, role="policy") + actions, log_prob, outputs = self.policy.act( + {"states": self._state_preprocessor(states), "alpha": self._alpha}, role="policy" + ) self._current_log_prob = log_prob return actions, log_prob, outputs - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: 
torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -258,7 +268,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: self._current_next_states = next_states @@ -268,7 +280,9 @@ def record_transition(self, rewards = self._rewards_shaper(rewards, timestep, timesteps) # compute values - values, _, _ = self.value.act({"states": self._state_preprocessor(states), "alpha": self._alpha}, role="value") + values, _, _ = self.value.act( + {"states": self._state_preprocessor(states), "alpha": self._alpha}, role="value" + ) values = self._value_preprocessor(values, inverse=True) # time-limit (truncation) boostrapping @@ -276,11 +290,27 @@ def record_transition(self, rewards += self._discount_factor * values * truncated # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -317,12 +347,15 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - def compute_gae(rewards: torch.Tensor, - dones: torch.Tensor, - values: torch.Tensor, - next_values: torch.Tensor, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> torch.Tensor: + + def compute_gae( + rewards: torch.Tensor, + dones: torch.Tensor, + values: torch.Tensor, + next_values: torch.Tensor, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, + ) -> torch.Tensor: """Compute the Generalized Advantage Estimator (GAE) :param rewards: Rewards obtained by the agent @@ -349,7 +382,11 @@ def compute_gae(rewards: torch.Tensor, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else last_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] + - values[i] + + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages[i] = 
advantage # returns computation returns = advantages + values @@ -361,17 +398,22 @@ def compute_gae(rewards: torch.Tensor, # compute returns and advantages with torch.no_grad(): self.value.train(False) - last_values, _, _ = self.value.act({"states": self._state_preprocessor(self._current_next_states.float()), "alpha": self._alpha}, role="value") + last_values, _, _ = self.value.act( + {"states": self._state_preprocessor(self._current_next_states.float()), "alpha": self._alpha}, + role="value", + ) self.value.train(True) last_values = self._value_preprocessor(last_values, inverse=True) values = self.memory.get_tensor_by_name("values") - returns, advantages = compute_gae(rewards=self.memory.get_tensor_by_name("rewards"), - dones=self.memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor, - lambda_coefficient=self._lambda) + returns, advantages = compute_gae( + rewards=self.memory.get_tensor_by_name("rewards"), + dones=self.memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor, + lambda_coefficient=self._lambda, + ) self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True)) self.memory.set_tensor_by_name("returns", self._value_preprocessor(returns, train=True)) @@ -389,11 +431,20 @@ def compute_gae(rewards: torch.Tensor, kl_divergences = [] # mini-batches loop - for sampled_states, sampled_actions, sampled_log_prob, sampled_values, sampled_returns, sampled_advantages in sampled_batches: + for ( + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_values, + sampled_returns, + sampled_advantages, + ) in sampled_batches: sampled_states = self._state_preprocessor(sampled_states, train=not epoch) - _, next_log_prob, _ = self.policy.act({"states": sampled_states, "taken_actions": sampled_actions, "alpha": self._alpha}, role="policy") + _, next_log_prob, _ = self.policy.act( + {"states": sampled_states, "taken_actions": sampled_actions, "alpha": self._alpha}, role="policy" + ) # compute approximate KL divergence with torch.no_grad(): @@ -414,7 +465,9 @@ def compute_gae(rewards: torch.Tensor, # compute policy loss ratio = torch.exp(next_log_prob - sampled_log_prob) surrogate = sampled_advantages * ratio - surrogate_clipped = sampled_advantages * torch.clip(ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip) + surrogate_clipped = sampled_advantages * torch.clip( + ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip + ) policy_loss = -torch.min(surrogate, surrogate_clipped).mean() @@ -422,9 +475,9 @@ def compute_gae(rewards: torch.Tensor, predicted_values, _, _ = self.value.act({"states": sampled_states, "alpha": self._alpha}, role="value") if self._clip_predicted_values: - predicted_values = sampled_values + torch.clip(predicted_values - sampled_values, - min=-self._value_clip, - max=self._value_clip) + predicted_values = sampled_values + torch.clip( + predicted_values - sampled_values, min=-self._value_clip, max=self._value_clip + ) value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values) # optimization step @@ -438,7 +491,9 @@ def compute_gae(rewards: torch.Tensor, if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) else: - nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip) + nn.utils.clip_grad_norm_( + itertools.chain(self.policy.parameters(), self.value.parameters()), 
self._grad_norm_clip + ) self.optimizer.step() # update cumulative losses @@ -463,7 +518,9 @@ def compute_gae(rewards: torch.Tensor, self.track_data("Loss / Policy loss", cumulative_policy_loss / (self._learning_epochs * self._mini_batches)) self.track_data("Loss / Value loss", cumulative_value_loss / (self._learning_epochs * self._mini_batches)) if self._entropy_loss_scale: - self.track_data("Loss / Entropy loss", cumulative_entropy_loss / (self._learning_epochs * self._mini_batches)) + self.track_data( + "Loss / Entropy loss", cumulative_entropy_loss / (self._learning_epochs * self._mini_batches) + ) self.track_data("Policy / Standard deviation", self.policy.distribution(role="policy").stddev.mean().item()) diff --git a/skrl/agents/torch/rpo/rpo_rnn.py b/skrl/agents/torch/rpo/rpo_rnn.py index c3ea7662..5cca8648 100644 --- a/skrl/agents/torch/rpo/rpo_rnn.py +++ b/skrl/agents/torch/rpo/rpo_rnn.py @@ -68,13 +68,15 @@ class RPO_RNN(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Robust Policy Optimization (RPO) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.) https://arxiv.org/abs/2212.07536 @@ -99,12 +101,14 @@ def __init__(self, """ _cfg = copy.deepcopy(RPO_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -159,10 +163,13 @@ def __init__(self, if self.policy is self.value: self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._learning_rate) else: - self.optimizer = torch.optim.Adam(itertools.chain(self.policy.parameters(), self.value.parameters()), - lr=self._learning_rate) + self.optimizer = torch.optim.Adam( + itertools.chain(self.policy.parameters(), self.value.parameters()), lr=self._learning_rate + ) if self._learning_rate_scheduler is not None: - self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.scheduler = self._learning_rate_scheduler( + self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["optimizer"] = self.optimizer @@ -180,8 +187,7 @@ def __init__(self, self._value_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -211,7 +217,9 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self._rnn = True # create tensors in memory if self.memory is not None: - self.memory.create_tensor(name=f"rnn_policy_{i}", 
size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True) + self.memory.create_tensor( + name=f"rnn_policy_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True + ) self._rnn_tensors_names.append(f"rnn_policy_{i}") # default RNN states self._rnn_initial_states["policy"].append(torch.zeros(size, dtype=torch.float32, device=self.device)) @@ -225,7 +233,9 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self._rnn = True # create tensors in memory if self.memory is not None: - self.memory.create_tensor(name=f"rnn_value_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True) + self.memory.create_tensor( + name=f"rnn_value_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True + ) self._rnn_tensors_names.append(f"rnn_value_{i}") # default RNN states self._rnn_initial_states["value"].append(torch.zeros(size, dtype=torch.float32, device=self.device)) @@ -255,7 +265,9 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": self._state_preprocessor(states), **rnn}, role="policy") # sample stochastic actions - actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states), "alpha": self._alpha, **rnn}, role="policy") + actions, log_prob, outputs = self.policy.act( + {"states": self._state_preprocessor(states), "alpha": self._alpha, **rnn}, role="policy" + ) self._current_log_prob = log_prob if self._rnn: @@ -263,16 +275,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, log_prob, outputs - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -294,7 +308,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: self._current_next_states = next_states @@ -305,7 +321,9 @@ def record_transition(self, # compute values rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {} - values, _, outputs = self.value.act({"states": self._state_preprocessor(states), "alpha": self._alpha, **rnn}, role="value") + values, _, outputs = self.value.act( + {"states": self._state_preprocessor(states), "alpha": self._alpha, **rnn}, role="value" + ) values = self._value_preprocessor(values, inverse=True) # time-limit (truncation) boostrapping @@ -315,20 +333,44 @@ def record_transition(self, # package RNN states rnn_states = {} if self._rnn: - rnn_states.update({f"rnn_policy_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["policy"])}) + rnn_states.update( + {f"rnn_policy_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["policy"])} + ) if self.policy is not self.value: - 
rnn_states.update({f"rnn_value_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["value"])}) + rnn_states.update( + {f"rnn_value_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["value"])} + ) # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values, **rnn_states) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + **rnn_states, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values, **rnn_states) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + **rnn_states, + ) # update RNN states if self._rnn: - self._rnn_final_states["value"] = self._rnn_final_states["policy"] if self.policy is self.value else outputs.get("rnn", []) + self._rnn_final_states["value"] = ( + self._rnn_final_states["policy"] if self.policy is self.value else outputs.get("rnn", []) + ) # reset states if the episodes have ended finished_episodes = terminated.nonzero(as_tuple=False) @@ -376,12 +418,15 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - def compute_gae(rewards: torch.Tensor, - dones: torch.Tensor, - values: torch.Tensor, - next_values: torch.Tensor, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> torch.Tensor: + + def compute_gae( + rewards: torch.Tensor, + dones: torch.Tensor, + values: torch.Tensor, + next_values: torch.Tensor, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, + ) -> torch.Tensor: """Compute the Generalized Advantage Estimator (GAE) :param rewards: Rewards obtained by the agent @@ -408,7 +453,11 @@ def compute_gae(rewards: torch.Tensor, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else last_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] + - values[i] + + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages[i] = advantage # returns computation returns = advantages + values @@ -421,28 +470,39 @@ def compute_gae(rewards: torch.Tensor, with torch.no_grad(): self.value.train(False) rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {} - last_values, _, _ = self.value.act({"states": self._state_preprocessor(self._current_next_states.float()), "alpha": self._alpha, **rnn}, role="value") + last_values, _, _ = self.value.act( + {"states": self._state_preprocessor(self._current_next_states.float()), "alpha": self._alpha, **rnn}, + role="value", + ) self.value.train(True) last_values = self._value_preprocessor(last_values, inverse=True) values = self.memory.get_tensor_by_name("values") - returns, advantages = compute_gae(rewards=self.memory.get_tensor_by_name("rewards"), - dones=self.memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, 
- discount_factor=self._discount_factor, - lambda_coefficient=self._lambda) + returns, advantages = compute_gae( + rewards=self.memory.get_tensor_by_name("rewards"), + dones=self.memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor, + lambda_coefficient=self._lambda, + ) self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True)) self.memory.set_tensor_by_name("returns", self._value_preprocessor(returns, train=True)) self.memory.set_tensor_by_name("advantages", advantages) # sample mini-batches from memory - sampled_batches = self.memory.sample_all(names=self._tensors_names, mini_batches=self._mini_batches, sequence_length=self._rnn_sequence_length) + sampled_batches = self.memory.sample_all( + names=self._tensors_names, mini_batches=self._mini_batches, sequence_length=self._rnn_sequence_length + ) rnn_policy, rnn_value = {}, {} if self._rnn: - sampled_rnn_batches = self.memory.sample_all(names=self._rnn_tensors_names, mini_batches=self._mini_batches, sequence_length=self._rnn_sequence_length) + sampled_rnn_batches = self.memory.sample_all( + names=self._rnn_tensors_names, + mini_batches=self._mini_batches, + sequence_length=self._rnn_sequence_length, + ) cumulative_policy_loss = 0 cumulative_entropy_loss = 0 @@ -453,19 +513,47 @@ def compute_gae(rewards: torch.Tensor, kl_divergences = [] # mini-batches loop - for i, (sampled_states, sampled_actions, sampled_dones, sampled_log_prob, sampled_values, sampled_returns, sampled_advantages) in enumerate(sampled_batches): + for i, ( + sampled_states, + sampled_actions, + sampled_dones, + sampled_log_prob, + sampled_values, + sampled_returns, + sampled_advantages, + ) in enumerate(sampled_batches): if self._rnn: if self.policy is self.value: - rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn_batches[i]], "terminated": sampled_dones} + rnn_policy = { + "rnn": [s.transpose(0, 1) for s in sampled_rnn_batches[i]], + "terminated": sampled_dones, + } rnn_value = rnn_policy else: - rnn_policy = {"rnn": [s.transpose(0, 1) for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) if "policy" in n], "terminated": sampled_dones} - rnn_value = {"rnn": [s.transpose(0, 1) for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) if "value" in n], "terminated": sampled_dones} + rnn_policy = { + "rnn": [ + s.transpose(0, 1) + for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) + if "policy" in n + ], + "terminated": sampled_dones, + } + rnn_value = { + "rnn": [ + s.transpose(0, 1) + for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) + if "value" in n + ], + "terminated": sampled_dones, + } sampled_states = self._state_preprocessor(sampled_states, train=not epoch) - _, next_log_prob, _ = self.policy.act({"states": sampled_states, "taken_actions": sampled_actions, "alpha": self._alpha, **rnn_policy}, role="policy") + _, next_log_prob, _ = self.policy.act( + {"states": sampled_states, "taken_actions": sampled_actions, "alpha": self._alpha, **rnn_policy}, + role="policy", + ) # compute approximate KL divergence with torch.no_grad(): @@ -486,17 +574,21 @@ def compute_gae(rewards: torch.Tensor, # compute policy loss ratio = torch.exp(next_log_prob - sampled_log_prob) surrogate = sampled_advantages * ratio - surrogate_clipped = sampled_advantages * torch.clip(ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip) + surrogate_clipped = sampled_advantages * torch.clip( + ratio, 1.0 - self._ratio_clip, 1.0 + 
self._ratio_clip + ) policy_loss = -torch.min(surrogate, surrogate_clipped).mean() # compute value loss - predicted_values, _, _ = self.value.act({"states": sampled_states, "alpha": self._alpha, **rnn_value}, role="value") + predicted_values, _, _ = self.value.act( + {"states": sampled_states, "alpha": self._alpha, **rnn_value}, role="value" + ) if self._clip_predicted_values: - predicted_values = sampled_values + torch.clip(predicted_values - sampled_values, - min=-self._value_clip, - max=self._value_clip) + predicted_values = sampled_values + torch.clip( + predicted_values - sampled_values, min=-self._value_clip, max=self._value_clip + ) value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values) # optimization step @@ -510,7 +602,9 @@ def compute_gae(rewards: torch.Tensor, if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) else: - nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip) + nn.utils.clip_grad_norm_( + itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip + ) self.optimizer.step() # update cumulative losses @@ -535,7 +629,9 @@ def compute_gae(rewards: torch.Tensor, self.track_data("Loss / Policy loss", cumulative_policy_loss / (self._learning_epochs * self._mini_batches)) self.track_data("Loss / Value loss", cumulative_value_loss / (self._learning_epochs * self._mini_batches)) if self._entropy_loss_scale: - self.track_data("Loss / Entropy loss", cumulative_entropy_loss / (self._learning_epochs * self._mini_batches)) + self.track_data( + "Loss / Entropy loss", cumulative_entropy_loss / (self._learning_epochs * self._mini_batches) + ) self.track_data("Policy / Standard deviation", self.policy.distribution(role="policy").stddev.mean().item()) diff --git a/skrl/agents/torch/sac/sac.py b/skrl/agents/torch/sac/sac.py index 4189fa85..df2a474c 100644 --- a/skrl/agents/torch/sac/sac.py +++ b/skrl/agents/torch/sac/sac.py @@ -61,13 +61,15 @@ class SAC(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Soft Actor-Critic (SAC) https://arxiv.org/abs/1801.01290 @@ -92,12 +94,14 @@ def __init__(self, """ _cfg = copy.deepcopy(SAC_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -167,7 +171,9 @@ def __init__(self, else: self._target_entropy = 0 - self.log_entropy_coefficient = torch.log(torch.ones(1, device=self.device) * self._entropy_coefficient).requires_grad_(True) + 
self.log_entropy_coefficient = torch.log( + torch.ones(1, device=self.device) * self._entropy_coefficient + ).requires_grad_(True) self.entropy_optimizer = torch.optim.Adam([self.log_entropy_coefficient], lr=self._entropy_learning_rate) self.checkpoint_modules["entropy_optimizer"] = self.entropy_optimizer @@ -175,11 +181,16 @@ def __init__(self, # set up optimizers and learning rate schedulers if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None: self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._actor_learning_rate) - self.critic_optimizer = torch.optim.Adam(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), - lr=self._critic_learning_rate) + self.critic_optimizer = torch.optim.Adam( + itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), lr=self._critic_learning_rate + ) if self._learning_rate_scheduler is not None: - self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) - self.critic_scheduler = self._learning_rate_scheduler(self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.policy_scheduler = self._learning_rate_scheduler( + self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) + self.critic_scheduler = self._learning_rate_scheduler( + self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer self.checkpoint_modules["critic_optimizer"] = self.critic_optimizer @@ -192,8 +203,7 @@ def __init__(self, self._state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -230,16 +240,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, None, outputs - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -261,7 +273,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: # reward shaping @@ -269,11 +283,23 @@ def record_transition(self, rewards = self._rewards_shaper(rewards, timestep, timesteps) # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, 
actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -314,8 +340,9 @@ def _update(self, timestep: int, timesteps: int) -> None: for gradient_step in range(self._gradient_steps): # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = self.memory.sample( + names=self._tensors_names, batch_size=self._batch_size + )[0] sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) @@ -324,14 +351,24 @@ def _update(self, timestep: int, timesteps: int) -> None: with torch.no_grad(): next_actions, next_log_prob, _ = self.policy.act({"states": sampled_next_states}, role="policy") - target_q1_values, _, _ = self.target_critic_1.act({"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_1") - target_q2_values, _, _ = self.target_critic_2.act({"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_2") - target_q_values = torch.min(target_q1_values, target_q2_values) - self._entropy_coefficient * next_log_prob + target_q1_values, _, _ = self.target_critic_1.act( + {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_1" + ) + target_q2_values, _, _ = self.target_critic_2.act( + {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_2" + ) + target_q_values = ( + torch.min(target_q1_values, target_q2_values) - self._entropy_coefficient * next_log_prob + ) target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values # compute critic loss - critic_1_values, _, _ = self.critic_1.act({"states": sampled_states, "taken_actions": sampled_actions}, role="critic_1") - critic_2_values, _, _ = self.critic_2.act({"states": sampled_states, "taken_actions": sampled_actions}, role="critic_2") + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="critic_1" + ) + critic_2_values, _, _ = self.critic_2.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="critic_2" + ) critic_loss = (F.mse_loss(critic_1_values, target_values) + F.mse_loss(critic_2_values, target_values)) / 2 @@ -342,13 +379,19 @@ def _update(self, timestep: int, timesteps: int) -> None: self.critic_1.reduce_parameters() self.critic_2.reduce_parameters() if self._grad_norm_clip > 0: - nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip) + nn.utils.clip_grad_norm_( + itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip + ) self.critic_optimizer.step() # compute policy (actor) loss actions, log_prob, _ = self.policy.act({"states": sampled_states}, role="policy") - critic_1_values, _, _ = self.critic_1.act({"states": sampled_states, "taken_actions": actions}, role="critic_1") - critic_2_values, _, _ = self.critic_2.act({"states": sampled_states, "taken_actions": 
actions}, role="critic_2") + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": actions}, role="critic_1" + ) + critic_2_values, _, _ = self.critic_2.act( + {"states": sampled_states, "taken_actions": actions}, role="critic_2" + ) policy_loss = (self._entropy_coefficient * log_prob - torch.min(critic_1_values, critic_2_values)).mean() diff --git a/skrl/agents/torch/sac/sac_rnn.py b/skrl/agents/torch/sac/sac_rnn.py index 6b162958..3160a27e 100644 --- a/skrl/agents/torch/sac/sac_rnn.py +++ b/skrl/agents/torch/sac/sac_rnn.py @@ -61,13 +61,15 @@ class SAC_RNN(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Soft Actor-Critic (SAC) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.) https://arxiv.org/abs/1801.01290 @@ -92,12 +94,14 @@ def __init__(self, """ _cfg = copy.deepcopy(SAC_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -167,7 +171,9 @@ def __init__(self, else: self._target_entropy = 0 - self.log_entropy_coefficient = torch.log(torch.ones(1, device=self.device) * self._entropy_coefficient).requires_grad_(True) + self.log_entropy_coefficient = torch.log( + torch.ones(1, device=self.device) * self._entropy_coefficient + ).requires_grad_(True) self.entropy_optimizer = torch.optim.Adam([self.log_entropy_coefficient], lr=self._entropy_learning_rate) self.checkpoint_modules["entropy_optimizer"] = self.entropy_optimizer @@ -175,11 +181,16 @@ def __init__(self, # set up optimizers and learning rate schedulers if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None: self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._actor_learning_rate) - self.critic_optimizer = torch.optim.Adam(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), - lr=self._critic_learning_rate) + self.critic_optimizer = torch.optim.Adam( + itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), lr=self._critic_learning_rate + ) if self._learning_rate_scheduler is not None: - self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) - self.critic_scheduler = self._learning_rate_scheduler(self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.policy_scheduler = self._learning_rate_scheduler( + self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) + self.critic_scheduler = self._learning_rate_scheduler( + self.critic_optimizer, 
**self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer self.checkpoint_modules["critic_optimizer"] = self.critic_optimizer @@ -192,8 +203,7 @@ def __init__(self, self._state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -219,7 +229,9 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self._rnn = True # create tensors in memory if self.memory is not None: - self.memory.create_tensor(name=f"rnn_policy_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True) + self.memory.create_tensor( + name=f"rnn_policy_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True + ) self._rnn_tensors_names.append(f"rnn_policy_{i}") # default RNN states self._rnn_initial_states["policy"].append(torch.zeros(size, dtype=torch.float32, device=self.device)) @@ -252,16 +264,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, None, outputs - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -283,7 +297,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: # reward shaping @@ -293,16 +309,32 @@ def record_transition(self, # package RNN states rnn_states = {} if self._rnn: - rnn_states.update({f"rnn_policy_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["policy"])}) + rnn_states.update( + {f"rnn_policy_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["policy"])} + ) # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, **rnn_states) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + **rnn_states, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, **rnn_states) - - # update RNN states + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + **rnn_states, + ) + + # update RNN states if self._rnn: # reset states if the episodes have ended finished_episodes = terminated.nonzero(as_tuple=False) @@ -351,12 +383,15 @@ def _update(self, timestep: int, timesteps: int) -> None: for gradient_step in 
range(self._gradient_steps): # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0] + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = self.memory.sample( + names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length + )[0] rnn_policy = {} if self._rnn: - sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0] + sampled_rnn = self.memory.sample_by_index( + names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes() + )[0] rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones} sampled_states = self._state_preprocessor(sampled_states, train=True) @@ -364,16 +399,28 @@ def _update(self, timestep: int, timesteps: int) -> None: # compute target values with torch.no_grad(): - next_actions, next_log_prob, _ = self.policy.act({"states": sampled_next_states, **rnn_policy}, role="policy") - - target_q1_values, _, _ = self.target_critic_1.act({"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, role="target_critic_1") - target_q2_values, _, _ = self.target_critic_2.act({"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, role="target_critic_2") - target_q_values = torch.min(target_q1_values, target_q2_values) - self._entropy_coefficient * next_log_prob + next_actions, next_log_prob, _ = self.policy.act( + {"states": sampled_next_states, **rnn_policy}, role="policy" + ) + + target_q1_values, _, _ = self.target_critic_1.act( + {"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, role="target_critic_1" + ) + target_q2_values, _, _ = self.target_critic_2.act( + {"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, role="target_critic_2" + ) + target_q_values = ( + torch.min(target_q1_values, target_q2_values) - self._entropy_coefficient * next_log_prob + ) target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values # compute critic loss - critic_1_values, _, _ = self.critic_1.act({"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_1") - critic_2_values, _, _ = self.critic_2.act({"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_2") + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_1" + ) + critic_2_values, _, _ = self.critic_2.act( + {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_2" + ) critic_loss = (F.mse_loss(critic_1_values, target_values) + F.mse_loss(critic_2_values, target_values)) / 2 @@ -384,13 +431,19 @@ def _update(self, timestep: int, timesteps: int) -> None: self.critic_1.reduce_parameters() self.critic_2.reduce_parameters() if self._grad_norm_clip > 0: - nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip) + nn.utils.clip_grad_norm_( + itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip + ) self.critic_optimizer.step() # compute policy (actor) loss actions, log_prob, _ = self.policy.act({"states": sampled_states, **rnn_policy}, role="policy") - critic_1_values, _, _ = 
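The target-value expression regrouped above combines the clipped double-Q minimum, the entropy bonus and the termination mask. A self-contained numeric sketch of the same arithmetic, with made-up batch values and an assumed entropy coefficient and discount factor:

import torch

rewards = torch.tensor([[1.0], [0.5]])
dones = torch.tensor([[False], [True]])
target_q1 = torch.tensor([[2.0], [3.0]])
target_q2 = torch.tensor([[2.5], [2.0]])
next_log_prob = torch.tensor([[-1.0], [-2.0]])
alpha, gamma = 0.2, 0.99                  # assumed entropy coefficient and discount factor

target_q = torch.min(target_q1, target_q2) - alpha * next_log_prob
target_values = rewards + gamma * dones.logical_not() * target_q
print(target_values)                      # the terminated row keeps only its immediate reward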
self.critic_1.act({"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic_1") - critic_2_values, _, _ = self.critic_2.act({"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic_2") + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic_1" + ) + critic_2_values, _, _ = self.critic_2.act( + {"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic_2" + ) policy_loss = (self._entropy_coefficient * log_prob - torch.min(critic_1_values, critic_2_values)).mean() diff --git a/skrl/agents/torch/sarsa/sarsa.py b/skrl/agents/torch/sarsa/sarsa.py index 3f079231..c7c4b7c3 100644 --- a/skrl/agents/torch/sarsa/sarsa.py +++ b/skrl/agents/torch/sarsa/sarsa.py @@ -39,13 +39,15 @@ class SARSA(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """State Action Reward State Action (SARSA) https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.17.2539 @@ -70,12 +72,14 @@ def __init__(self, """ _cfg = copy.deepcopy(SARSA_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -101,8 +105,7 @@ def __init__(self, self._current_dones = None def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tensor: @@ -125,16 +128,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens # sample actions from policy return self.policy.act({"states": states}, role="policy") - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -156,7 +161,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + 
states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) # reward shaping if self._rewards_shaper is not None: @@ -169,11 +176,23 @@ def record_transition(self, self._current_dones = terminated + truncated if self.memory is not None: - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -214,7 +233,10 @@ def _update(self, timestep: int, timesteps: int) -> None: next_actions = self.policy.act({"states": self._current_next_states}, role="policy")[0] # update Q-table - q_table[env_ids, self._current_states, self._current_actions] += self._learning_rate \ - * (self._current_rewards + self._discount_factor * self._current_dones.logical_not() \ - * q_table[env_ids, self._current_next_states, next_actions] \ - - q_table[env_ids, self._current_states, self._current_actions]) + q_table[env_ids, self._current_states, self._current_actions] += self._learning_rate * ( + self._current_rewards + + self._discount_factor + * self._current_dones.logical_not() + * q_table[env_ids, self._current_next_states, next_actions] + - q_table[env_ids, self._current_states, self._current_actions] + ) diff --git a/skrl/agents/torch/td3/td3.py b/skrl/agents/torch/td3/td3.py index 3bc7930b..7e5b7862 100644 --- a/skrl/agents/torch/td3/td3.py +++ b/skrl/agents/torch/td3/td3.py @@ -66,13 +66,15 @@ class TD3(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Twin Delayed DDPG (TD3) https://arxiv.org/abs/1802.09477 @@ -97,12 +99,14 @@ def __init__(self, """ _cfg = copy.deepcopy(TD3_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -177,11 +181,16 @@ def __init__(self, # set up optimizers and learning rate schedulers if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None: self.policy_optimizer = 
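The Q-table update reformatted above is the classic on-policy SARSA rule, Q(s, a) += lr * (r + gamma * (1 - done) * Q(s', a') - Q(s, a)). A tiny standalone sketch with a hypothetical 3-state, 2-action table:

import torch

q_table = torch.zeros(3, 2)               # hypothetical 3 states x 2 actions
lr, gamma = 0.1, 0.99
s, a, r, s_next, a_next, done = 0, 1, 1.0, 2, 0, False

td_target = r + gamma * (not done) * q_table[s_next, a_next]
q_table[s, a] += lr * (td_target - q_table[s, a])
print(q_table)                            # only the visited (state, action) entry changed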
torch.optim.Adam(self.policy.parameters(), lr=self._actor_learning_rate) - self.critic_optimizer = torch.optim.Adam(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), - lr=self._critic_learning_rate) + self.critic_optimizer = torch.optim.Adam( + itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), lr=self._critic_learning_rate + ) if self._learning_rate_scheduler is not None: - self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) - self.critic_scheduler = self._learning_rate_scheduler(self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.policy_scheduler = self._learning_rate_scheduler( + self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) + self.critic_scheduler = self._learning_rate_scheduler( + self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer self.checkpoint_modules["critic_optimizer"] = self.critic_optimizer @@ -194,8 +203,7 @@ def __init__(self, self._state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -246,9 +254,9 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens # apply exploration noise if timestep <= self._exploration_timesteps: - scale = (1 - timestep / self._exploration_timesteps) \ - * (self._exploration_initial_scale - self._exploration_final_scale) \ - + self._exploration_final_scale + scale = (1 - timestep / self._exploration_timesteps) * ( + self._exploration_initial_scale - self._exploration_final_scale + ) + self._exploration_final_scale noises.mul_(scale) # modify actions @@ -268,16 +276,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, None, outputs - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -299,7 +309,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: # reward shaping @@ -307,11 +319,23 @@ def record_transition(self, rewards = self._rewards_shaper(rewards, timestep, timesteps) # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) for memory in self.secondary_memories: - 
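The exploration-noise scale regrouped above is a linear interpolation from an initial scale down to a final scale over the exploration window. A quick standalone check with assumed values:

def noise_scale(timestep, exploration_timesteps, initial_scale=1.0, final_scale=1e-3):
    # linear decay: initial_scale at timestep 0, final_scale at timestep == exploration_timesteps
    return (1 - timestep / exploration_timesteps) * (initial_scale - final_scale) + final_scale

print(noise_scale(0, 1000))               # 1.0
print(noise_scale(500, 1000))             # 0.5005
print(noise_scale(1000, 1000))            # 0.001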
memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -352,8 +376,9 @@ def _update(self, timestep: int, timesteps: int) -> None: for gradient_step in range(self._gradient_steps): # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0] + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = self.memory.sample( + names=self._tensors_names, batch_size=self._batch_size + )[0] sampled_states = self._state_preprocessor(sampled_states, train=True) sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) @@ -362,21 +387,31 @@ def _update(self, timestep: int, timesteps: int) -> None: # target policy smoothing next_actions, _, _ = self.target_policy.act({"states": sampled_next_states}, role="target_policy") if self._smooth_regularization_noise is not None: - noises = torch.clamp(self._smooth_regularization_noise.sample(next_actions.shape), - min=-self._smooth_regularization_clip, - max=self._smooth_regularization_clip) + noises = torch.clamp( + self._smooth_regularization_noise.sample(next_actions.shape), + min=-self._smooth_regularization_clip, + max=self._smooth_regularization_clip, + ) next_actions.add_(noises) next_actions.clamp_(min=self.clip_actions_min, max=self.clip_actions_max) # compute target values - target_q1_values, _, _ = self.target_critic_1.act({"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_1") - target_q2_values, _, _ = self.target_critic_2.act({"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_2") + target_q1_values, _, _ = self.target_critic_1.act( + {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_1" + ) + target_q2_values, _, _ = self.target_critic_2.act( + {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_2" + ) target_q_values = torch.min(target_q1_values, target_q2_values) target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values # compute critic loss - critic_1_values, _, _ = self.critic_1.act({"states": sampled_states, "taken_actions": sampled_actions}, role="critic_1") - critic_2_values, _, _ = self.critic_2.act({"states": sampled_states, "taken_actions": sampled_actions}, role="critic_2") + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="critic_1" + ) + critic_2_values, _, _ = self.critic_2.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="critic_2" + ) critic_loss = F.mse_loss(critic_1_values, target_values) + F.mse_loss(critic_2_values, target_values) @@ -387,7 +422,9 @@ def _update(self, timestep: int, timesteps: int) -> None: self.critic_1.reduce_parameters() self.critic_2.reduce_parameters() if self._grad_norm_clip > 0: - nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip) + nn.utils.clip_grad_norm_( + itertools.chain(self.critic_1.parameters(), 
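Target policy smoothing, as reformatted above, adds clipped noise to the target actions and then clamps them back into the action bounds. A standalone sketch in which plain Gaussian noise stands in for the configured smoothing noise and the bounds are assumed to be [-1, 1]:

import torch

next_actions = torch.tensor([[0.9, -0.2]])
noise_std, noise_clip = 0.2, 0.5          # assumed smoothing std and clip value
low, high = -1.0, 1.0                     # assumed action bounds

noises = torch.clamp(torch.randn_like(next_actions) * noise_std, min=-noise_clip, max=noise_clip)
smoothed_actions = (next_actions + noises).clamp_(min=low, max=high)
print(smoothed_actions)                   # perturbed target actions, still inside the bounds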
self.critic_2.parameters()), self._grad_norm_clip + ) self.critic_optimizer.step() # delayed update @@ -396,7 +433,9 @@ def _update(self, timestep: int, timesteps: int) -> None: # compute policy (actor) loss actions, _, _ = self.policy.act({"states": sampled_states}, role="policy") - critic_values, _, _ = self.critic_1.act({"states": sampled_states, "taken_actions": actions}, role="critic_1") + critic_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": actions}, role="critic_1" + ) policy_loss = -critic_values.mean() diff --git a/skrl/agents/torch/td3/td3_rnn.py b/skrl/agents/torch/td3/td3_rnn.py index 81b7f313..39d1aeed 100644 --- a/skrl/agents/torch/td3/td3_rnn.py +++ b/skrl/agents/torch/td3/td3_rnn.py @@ -66,13 +66,15 @@ class TD3_RNN(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Twin Delayed DDPG (TD3) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.) https://arxiv.org/abs/1802.09477 @@ -97,12 +99,14 @@ def __init__(self, """ _cfg = copy.deepcopy(TD3_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -177,11 +181,16 @@ def __init__(self, # set up optimizers and learning rate schedulers if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None: self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._actor_learning_rate) - self.critic_optimizer = torch.optim.Adam(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), - lr=self._critic_learning_rate) + self.critic_optimizer = torch.optim.Adam( + itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), lr=self._critic_learning_rate + ) if self._learning_rate_scheduler is not None: - self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) - self.critic_scheduler = self._learning_rate_scheduler(self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.policy_scheduler = self._learning_rate_scheduler( + self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) + self.critic_scheduler = self._learning_rate_scheduler( + self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer self.checkpoint_modules["critic_optimizer"] = self.critic_optimizer @@ -194,8 +203,7 @@ def __init__(self, self._state_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - 
"""Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -221,7 +229,9 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self._rnn = True # create tensors in memory if self.memory is not None: - self.memory.create_tensor(name=f"rnn_policy_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True) + self.memory.create_tensor( + name=f"rnn_policy_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True + ) self._rnn_tensors_names.append(f"rnn_policy_{i}") # default RNN states self._rnn_initial_states["policy"].append(torch.zeros(size, dtype=torch.float32, device=self.device)) @@ -268,9 +278,9 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens # apply exploration noise if timestep <= self._exploration_timesteps: - scale = (1 - timestep / self._exploration_timesteps) \ - * (self._exploration_initial_scale - self._exploration_final_scale) \ - + self._exploration_final_scale + scale = (1 - timestep / self._exploration_timesteps) * ( + self._exploration_initial_scale - self._exploration_final_scale + ) + self._exploration_final_scale noises.mul_(scale) # modify actions @@ -290,16 +300,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, None, outputs - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -321,7 +333,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: # reward shaping @@ -331,14 +345,30 @@ def record_transition(self, # package RNN states rnn_states = {} if self._rnn: - rnn_states.update({f"rnn_policy_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["policy"])}) + rnn_states.update( + {f"rnn_policy_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["policy"])} + ) # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, **rnn_states) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + **rnn_states, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, **rnn_states) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + **rnn_states, + ) # update RNN states if self._rnn: @@ -389,12 +419,15 @@ def _update(self, timestep: 
int, timesteps: int) -> None: for gradient_step in range(self._gradient_steps): # sample a batch from memory - sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \ - self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0] + sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = self.memory.sample( + names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length + )[0] rnn_policy = {} if self._rnn: - sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0] + sampled_rnn = self.memory.sample_by_index( + names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes() + )[0] rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones} sampled_states = self._state_preprocessor(sampled_states, train=True) @@ -402,23 +435,35 @@ def _update(self, timestep: int, timesteps: int) -> None: with torch.no_grad(): # target policy smoothing - next_actions, _, _ = self.target_policy.act({"states": sampled_next_states, **rnn_policy}, role="target_policy") + next_actions, _, _ = self.target_policy.act( + {"states": sampled_next_states, **rnn_policy}, role="target_policy" + ) if self._smooth_regularization_noise is not None: - noises = torch.clamp(self._smooth_regularization_noise.sample(next_actions.shape), - min=-self._smooth_regularization_clip, - max=self._smooth_regularization_clip) + noises = torch.clamp( + self._smooth_regularization_noise.sample(next_actions.shape), + min=-self._smooth_regularization_clip, + max=self._smooth_regularization_clip, + ) next_actions.add_(noises) next_actions.clamp_(min=self.clip_actions_min, max=self.clip_actions_max) # compute target values - target_q1_values, _, _ = self.target_critic_1.act({"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, role="target_critic_1") - target_q2_values, _, _ = self.target_critic_2.act({"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, role="target_critic_2") + target_q1_values, _, _ = self.target_critic_1.act( + {"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, role="target_critic_1" + ) + target_q2_values, _, _ = self.target_critic_2.act( + {"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, role="target_critic_2" + ) target_q_values = torch.min(target_q1_values, target_q2_values) target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values # compute critic loss - critic_1_values, _, _ = self.critic_1.act({"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_1") - critic_2_values, _, _ = self.critic_2.act({"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_2") + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_1" + ) + critic_2_values, _, _ = self.critic_2.act( + {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_2" + ) critic_loss = F.mse_loss(critic_1_values, target_values) + F.mse_loss(critic_2_values, target_values) @@ -429,7 +474,9 @@ def _update(self, timestep: int, timesteps: int) -> None: self.critic_1.reduce_parameters() self.critic_2.reduce_parameters() if self._grad_norm_clip > 0: - 
nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip) + nn.utils.clip_grad_norm_( + itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip + ) self.critic_optimizer.step() # delayed update @@ -438,7 +485,9 @@ def _update(self, timestep: int, timesteps: int) -> None: # compute policy (actor) loss actions, _, _ = self.policy.act({"states": sampled_states, **rnn_policy}, role="policy") - critic_values, _, _ = self.critic_1.act({"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic_1") + critic_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic_1" + ) policy_loss = -critic_values.mean() diff --git a/skrl/agents/torch/trpo/trpo.py b/skrl/agents/torch/trpo/trpo.py index 56757867..12e3910a 100644 --- a/skrl/agents/torch/trpo/trpo.py +++ b/skrl/agents/torch/trpo/trpo.py @@ -66,13 +66,15 @@ class TRPO(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Trust Region Policy Optimization (TRPO) https://arxiv.org/abs/1502.05477 @@ -97,12 +99,14 @@ def __init__(self, """ _cfg = copy.deepcopy(TRPO_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -157,7 +161,9 @@ def __init__(self, if self.policy is not None and self.value is not None: self.value_optimizer = torch.optim.Adam(self.value.parameters(), lr=self._value_learning_rate) if self._learning_rate_scheduler is not None: - self.value_scheduler = self._learning_rate_scheduler(self.value_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.value_scheduler = self._learning_rate_scheduler( + self.value_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["value_optimizer"] = self.value_optimizer @@ -175,8 +181,7 @@ def __init__(self, self._value_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -222,16 +227,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, log_prob, outputs - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: 
torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -253,7 +260,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: self._current_next_states = next_states @@ -271,11 +280,27 @@ def record_transition(self, rewards += self._discount_factor * values * truncated # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -312,12 +337,15 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - def compute_gae(rewards: torch.Tensor, - dones: torch.Tensor, - values: torch.Tensor, - next_values: torch.Tensor, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> torch.Tensor: + + def compute_gae( + rewards: torch.Tensor, + dones: torch.Tensor, + values: torch.Tensor, + next_values: torch.Tensor, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, + ) -> torch.Tensor: """Compute the Generalized Advantage Estimator (GAE) :param rewards: Rewards obtained by the agent @@ -344,7 +372,11 @@ def compute_gae(rewards: torch.Tensor, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else last_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] + - values[i] + + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages[i] = advantage # returns computation returns = advantages + values @@ -353,11 +385,9 @@ def compute_gae(rewards: torch.Tensor, return returns, advantages - def surrogate_loss(policy: Model, - states: torch.Tensor, - actions: torch.Tensor, - log_prob: torch.Tensor, - advantages: torch.Tensor) -> torch.Tensor: + def surrogate_loss( + policy: Model, states: torch.Tensor, actions: torch.Tensor, log_prob: torch.Tensor, advantages: torch.Tensor + ) -> torch.Tensor: """Compute the surrogate objective (policy loss) :param policy: Policy @@ -377,11 +407,13 @@ def 
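The GAE recursion regrouped above folds the TD error and the lambda-discounted bootstrap into a single backward pass over the rollout. A compact standalone version of the same loop over dummy data (the advantage-standardization step is omitted here):

import torch

def compute_gae(rewards, dones, values, next_values, discount_factor=0.99, lambda_coefficient=0.95):
    advantages = torch.zeros_like(rewards)
    advantage = 0
    not_dones = dones.logical_not()
    memory_size = rewards.shape[0]
    for i in reversed(range(memory_size)):
        bootstrap = values[i + 1] if i < memory_size - 1 else next_values
        advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (bootstrap + lambda_coefficient * advantage)
        advantages[i] = advantage
    returns = advantages + values
    return returns, advantages

rewards = torch.tensor([1.0, 1.0, 1.0])
dones = torch.tensor([False, False, True])
values = torch.tensor([0.5, 0.5, 0.5])
print(compute_gae(rewards, dones, values, next_values=torch.tensor(0.0)))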
surrogate_loss(policy: Model, _, new_log_prob, _ = policy.act({"states": states, "taken_actions": actions}, role="policy") return (advantages * torch.exp(new_log_prob - log_prob.detach())).mean() - def conjugate_gradient(policy: Model, - states: torch.Tensor, - b: torch.Tensor, - num_iterations: float = 10, - residual_tolerance: float = 1e-10) -> torch.Tensor: + def conjugate_gradient( + policy: Model, + states: torch.Tensor, + b: torch.Tensor, + num_iterations: float = 10, + residual_tolerance: float = 1e-10, + ) -> torch.Tensor: """Conjugate gradient algorithm to solve Ax = b using the iterative method https://en.wikipedia.org/wiki/Conjugate_gradient_method#As_an_iterative_method @@ -416,10 +448,9 @@ def conjugate_gradient(policy: Model, rr_old = rr_new return x - def fisher_vector_product(policy: Model, - states: torch.Tensor, - vector: torch.Tensor, - damping: float = 0.1) -> torch.Tensor: + def fisher_vector_product( + policy: Model, states: torch.Tensor, vector: torch.Tensor, damping: float = 0.1 + ) -> torch.Tensor: """Compute the Fisher vector product (direct method) https://www.telesens.co/2018/06/09/efficiently-computing-the-fisher-vector-product-in-trpo/ @@ -440,7 +471,9 @@ def fisher_vector_product(policy: Model, kl_gradient = torch.autograd.grad(kl, policy.parameters(), create_graph=True) flat_kl_gradient = torch.cat([gradient.view(-1) for gradient in kl_gradient]) hessian_vector_gradient = torch.autograd.grad((flat_kl_gradient * vector).sum(), policy.parameters()) - flat_hessian_vector_gradient = torch.cat([gradient.contiguous().view(-1) for gradient in hessian_vector_gradient]) + flat_hessian_vector_gradient = torch.cat( + [gradient.contiguous().view(-1) for gradient in hessian_vector_gradient] + ) return flat_hessian_vector_gradient + damping * vector def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> torch.Tensor: @@ -465,32 +498,41 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor mu_2 = policy_2.act({"states": states}, role="policy")[2]["mean_actions"] logstd_2 = policy_2.get_log_std(role="policy") - kl = logstd_1 - logstd_2 + 0.5 * (torch.square(logstd_1.exp()) + torch.square(mu_1 - mu_2)) \ - / torch.square(logstd_2.exp()) - 0.5 + kl = ( + logstd_1 + - logstd_2 + + 0.5 * (torch.square(logstd_1.exp()) + torch.square(mu_1 - mu_2)) / torch.square(logstd_2.exp()) + - 0.5 + ) return torch.sum(kl, dim=-1).mean() # compute returns and advantages with torch.no_grad(): self.value.train(False) - last_values, _, _ = self.value.act({"states": self._state_preprocessor(self._current_next_states.float())}, role="value") + last_values, _, _ = self.value.act( + {"states": self._state_preprocessor(self._current_next_states.float())}, role="value" + ) self.value.train(True) last_values = self._value_preprocessor(last_values, inverse=True) values = self.memory.get_tensor_by_name("values") - returns, advantages = compute_gae(rewards=self.memory.get_tensor_by_name("rewards"), - dones=self.memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor, - lambda_coefficient=self._lambda) + returns, advantages = compute_gae( + rewards=self.memory.get_tensor_by_name("rewards"), + dones=self.memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor, + lambda_coefficient=self._lambda, + ) self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True)) 
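For reference, the conjugate-gradient helper reformatted above solves Ax = b iteratively; in TRPO the matrix A is never built explicitly and is replaced by Fisher-vector products. A generic dense-matrix sketch of the same iterative method, with a small assumed symmetric positive-definite system:

import torch

def conjugate_gradient(A, b, num_iterations=10, residual_tolerance=1e-10):
    # solve A x = b for symmetric positive-definite A
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rr_old = torch.dot(r, r)
    for _ in range(num_iterations):
        Ap = A @ p
        alpha = rr_old / torch.dot(p, Ap)
        x += alpha * p
        r -= alpha * Ap
        rr_new = torch.dot(r, r)
        if rr_new < residual_tolerance:
            break
        p = r + (rr_new / rr_old) * p
        rr_old = rr_new
    return x

A = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
b = torch.tensor([1.0, 2.0])
x = conjugate_gradient(A, b)
print(x, A @ x)                           # A @ x should be close to b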
self.memory.set_tensor_by_name("returns", self._value_preprocessor(returns, train=True)) self.memory.set_tensor_by_name("advantages", advantages) # sample all from memory - sampled_states, sampled_actions, sampled_log_prob, sampled_advantages \ - = self.memory.sample_all(names=self._tensors_names_policy, mini_batches=1)[0] + sampled_states, sampled_actions, sampled_log_prob, sampled_advantages = self.memory.sample_all( + names=self._tensors_names_policy, mini_batches=1 + )[0] sampled_states = self._state_preprocessor(sampled_states, train=True) @@ -500,12 +542,14 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor flat_policy_loss_gradient = torch.cat([gradient.view(-1) for gradient in policy_loss_gradient]) # compute the search direction using the conjugate gradient algorithm - search_direction = conjugate_gradient(self.policy, sampled_states, flat_policy_loss_gradient.data, - num_iterations=self._conjugate_gradient_steps) + search_direction = conjugate_gradient( + self.policy, sampled_states, flat_policy_loss_gradient.data, num_iterations=self._conjugate_gradient_steps + ) # compute step size and full step - xHx = (search_direction * fisher_vector_product(self.policy, sampled_states, search_direction, self._damping)) \ - .sum(0, keepdim=True) + xHx = ( + search_direction * fisher_vector_product(self.policy, sampled_states, search_direction, self._damping) + ).sum(0, keepdim=True) step_size = torch.sqrt(2 * self._max_kl_divergence / xHx)[0] full_step = step_size * search_direction @@ -516,7 +560,7 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor expected_improvement = (flat_policy_loss_gradient * full_step).sum(0, keepdim=True) - for alpha in [self._step_fraction * 0.5 ** i for i in range(self._max_backtrack_steps)]: + for alpha in [self._step_fraction * 0.5**i for i in range(self._max_backtrack_steps)]: new_params = params + alpha * full_step vector_to_parameters(new_params, self.policy.parameters()) diff --git a/skrl/agents/torch/trpo/trpo_rnn.py b/skrl/agents/torch/trpo/trpo_rnn.py index 2f6a8e61..ff7804c6 100644 --- a/skrl/agents/torch/trpo/trpo_rnn.py +++ b/skrl/agents/torch/trpo/trpo_rnn.py @@ -66,13 +66,15 @@ class TRPO_RNN(Agent): - def __init__(self, - models: Mapping[str, Model], - memory: Optional[Union[Memory, Tuple[Memory]]] = None, - observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + models: Mapping[str, Model], + memory: Optional[Union[Memory, Tuple[Memory]]] = None, + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Trust Region Policy Optimization (TRPO) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.) 
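Two details of the TRPO step reformatted above are easy to verify numerically: the step size sqrt(2 * max_kl / xHx) makes the quadratic KL model of the full step equal the trust-region bound, and the backtracking loop tries geometrically shrinking fractions of that step. A small check with a stand-in curvature matrix and assumed hyperparameters:

import torch

H = torch.tensor([[2.0, 0.0], [0.0, 0.5]])        # stand-in for the Fisher matrix
search_direction = torch.tensor([1.0, 2.0])
max_kl_divergence = 0.01                           # assumed trust-region size
step_fraction, max_backtrack_steps = 1.0, 4        # assumed line-search settings

xHx = search_direction @ H @ search_direction
step_size = torch.sqrt(2 * max_kl_divergence / xHx)
full_step = step_size * search_direction

print(0.5 * full_step @ H @ full_step)             # quadratic KL model of the step: ~0.01
print([step_fraction * 0.5**i for i in range(max_backtrack_steps)])   # [1.0, 0.5, 0.25, 0.125]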
https://arxiv.org/abs/1502.05477 @@ -97,12 +99,14 @@ def __init__(self, """ _cfg = copy.deepcopy(TRPO_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(models=models, - memory=memory, - observation_space=observation_space, - action_space=action_space, - device=device, - cfg=_cfg) + super().__init__( + models=models, + memory=memory, + observation_space=observation_space, + action_space=action_space, + device=device, + cfg=_cfg, + ) # models self.policy = self.models.get("policy", None) @@ -157,7 +161,9 @@ def __init__(self, if self.policy is not None and self.value is not None: self.value_optimizer = torch.optim.Adam(self.value.parameters(), lr=self._value_learning_rate) if self._learning_rate_scheduler is not None: - self.value_scheduler = self._learning_rate_scheduler(self.value_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]) + self.value_scheduler = self._learning_rate_scheduler( + self.value_optimizer, **self.cfg["learning_rate_scheduler_kwargs"] + ) self.checkpoint_modules["value_optimizer"] = self.value_optimizer @@ -175,8 +181,7 @@ def __init__(self, self._value_preprocessor = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -206,7 +211,9 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self._rnn = True # create tensors in memory if self.memory is not None: - self.memory.create_tensor(name=f"rnn_policy_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True) + self.memory.create_tensor( + name=f"rnn_policy_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True + ) self._rnn_tensors_names.append(f"rnn_policy_{i}") # default RNN states self._rnn_initial_states["policy"].append(torch.zeros(size, dtype=torch.float32, device=self.device)) @@ -220,7 +227,9 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self._rnn = True # create tensors in memory if self.memory is not None: - self.memory.create_tensor(name=f"rnn_value_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True) + self.memory.create_tensor( + name=f"rnn_value_{i}", size=(size[0], size[2]), dtype=torch.float32, keep_dimensions=True + ) self._rnn_tensors_names.append(f"rnn_value_{i}") # default RNN states self._rnn_initial_states["value"].append(torch.zeros(size, dtype=torch.float32, device=self.device)) @@ -258,16 +267,18 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return actions, log_prob, outputs - def record_transition(self, - states: torch.Tensor, - actions: torch.Tensor, - rewards: torch.Tensor, - next_states: torch.Tensor, - terminated: torch.Tensor, - truncated: torch.Tensor, - infos: Any, - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_states: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + infos: Any, + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -289,7 +300,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, 
actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memory is not None: self._current_next_states = next_states @@ -310,20 +323,44 @@ def record_transition(self, # package RNN states rnn_states = {} if self._rnn: - rnn_states.update({f"rnn_policy_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["policy"])}) + rnn_states.update( + {f"rnn_policy_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["policy"])} + ) if self.policy is not self.value: - rnn_states.update({f"rnn_value_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["value"])}) + rnn_states.update( + {f"rnn_value_{i}": s.transpose(0, 1) for i, s in enumerate(self._rnn_initial_states["value"])} + ) # storage transition in memory - self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values, **rnn_states) + self.memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + **rnn_states, + ) for memory in self.secondary_memories: - memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states, - terminated=terminated, truncated=truncated, log_prob=self._current_log_prob, values=values, **rnn_states) + memory.add_samples( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + log_prob=self._current_log_prob, + values=values, + **rnn_states, + ) # update RNN states if self._rnn: - self._rnn_final_states["value"] = self._rnn_final_states["policy"] if self.policy is self.value else outputs.get("rnn", []) + self._rnn_final_states["value"] = ( + self._rnn_final_states["policy"] if self.policy is self.value else outputs.get("rnn", []) + ) # reset states if the episodes have ended finished_episodes = terminated.nonzero(as_tuple=False) @@ -371,12 +408,15 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - def compute_gae(rewards: torch.Tensor, - dones: torch.Tensor, - values: torch.Tensor, - next_values: torch.Tensor, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> torch.Tensor: + + def compute_gae( + rewards: torch.Tensor, + dones: torch.Tensor, + values: torch.Tensor, + next_values: torch.Tensor, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, + ) -> torch.Tensor: """Compute the Generalized Advantage Estimator (GAE) :param rewards: Rewards obtained by the agent @@ -403,7 +443,11 @@ def compute_gae(rewards: torch.Tensor, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else last_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] + - values[i] + + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages[i] = advantage # returns computation returns = advantages + values @@ -412,11 +456,9 @@ def compute_gae(rewards: torch.Tensor, return returns, advantages - def surrogate_loss(policy: Model, - states: torch.Tensor, - actions: torch.Tensor, - log_prob: torch.Tensor, - advantages: torch.Tensor) -> torch.Tensor: + def surrogate_loss( + policy: Model, 
states: torch.Tensor, actions: torch.Tensor, log_prob: torch.Tensor, advantages: torch.Tensor + ) -> torch.Tensor: """Compute the surrogate objective (policy loss) :param policy: Policy @@ -436,11 +478,13 @@ def surrogate_loss(policy: Model, _, new_log_prob, _ = policy.act({"states": states, "taken_actions": actions, **rnn_policy}, role="policy") return (advantages * torch.exp(new_log_prob - log_prob.detach())).mean() - def conjugate_gradient(policy: Model, - states: torch.Tensor, - b: torch.Tensor, - num_iterations: float = 10, - residual_tolerance: float = 1e-10) -> torch.Tensor: + def conjugate_gradient( + policy: Model, + states: torch.Tensor, + b: torch.Tensor, + num_iterations: float = 10, + residual_tolerance: float = 1e-10, + ) -> torch.Tensor: """Conjugate gradient algorithm to solve Ax = b using the iterative method https://en.wikipedia.org/wiki/Conjugate_gradient_method#As_an_iterative_method @@ -475,10 +519,9 @@ def conjugate_gradient(policy: Model, rr_old = rr_new return x - def fisher_vector_product(policy: Model, - states: torch.Tensor, - vector: torch.Tensor, - damping: float = 0.1) -> torch.Tensor: + def fisher_vector_product( + policy: Model, states: torch.Tensor, vector: torch.Tensor, damping: float = 0.1 + ) -> torch.Tensor: """Compute the Fisher vector product (direct method) https://www.telesens.co/2018/06/09/efficiently-computing-the-fisher-vector-product-in-trpo/ @@ -499,7 +542,9 @@ def fisher_vector_product(policy: Model, kl_gradient = torch.autograd.grad(kl, policy.parameters(), create_graph=True) flat_kl_gradient = torch.cat([gradient.view(-1) for gradient in kl_gradient]) hessian_vector_gradient = torch.autograd.grad((flat_kl_gradient * vector).sum(), policy.parameters()) - flat_hessian_vector_gradient = torch.cat([gradient.contiguous().view(-1) for gradient in hessian_vector_gradient]) + flat_hessian_vector_gradient = torch.cat( + [gradient.contiguous().view(-1) for gradient in hessian_vector_gradient] + ) return flat_hessian_vector_gradient + damping * vector def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> torch.Tensor: @@ -525,34 +570,45 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor mu_2 = policy_2.act({"states": states, **rnn_policy}, role="policy")[2]["mean_actions"] logstd_2 = policy_2.get_log_std(role="policy") - kl = logstd_1 - logstd_2 + 0.5 * (torch.square(logstd_1.exp()) + torch.square(mu_1 - mu_2)) \ - / torch.square(logstd_2.exp()) - 0.5 + kl = ( + logstd_1 + - logstd_2 + + 0.5 * (torch.square(logstd_1.exp()) + torch.square(mu_1 - mu_2)) / torch.square(logstd_2.exp()) + - 0.5 + ) return torch.sum(kl, dim=-1).mean() # compute returns and advantages with torch.no_grad(): self.value.train(False) rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {} - last_values, _, _ = self.value.act({"states": self._state_preprocessor(self._current_next_states.float()), **rnn}, role="value") + last_values, _, _ = self.value.act( + {"states": self._state_preprocessor(self._current_next_states.float()), **rnn}, role="value" + ) self.value.train(True) last_values = self._value_preprocessor(last_values, inverse=True) values = self.memory.get_tensor_by_name("values") - returns, advantages = compute_gae(rewards=self.memory.get_tensor_by_name("rewards"), - dones=self.memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor, - lambda_coefficient=self._lambda) + returns, advantages = compute_gae( + 
rewards=self.memory.get_tensor_by_name("rewards"), + dones=self.memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor, + lambda_coefficient=self._lambda, + ) self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True)) self.memory.set_tensor_by_name("returns", self._value_preprocessor(returns, train=True)) self.memory.set_tensor_by_name("advantages", advantages) # sample all from memory - sampled_states, sampled_actions, sampled_dones, sampled_log_prob, sampled_advantages \ - = self.memory.sample_all(names=self._tensors_names_policy, mini_batches=1, sequence_length=self._rnn_sequence_length)[0] - sampled_rnn_batches = self.memory.sample_all(names=self._rnn_tensors_names, mini_batches=1, sequence_length=self._rnn_sequence_length)[0] + sampled_states, sampled_actions, sampled_dones, sampled_log_prob, sampled_advantages = self.memory.sample_all( + names=self._tensors_names_policy, mini_batches=1, sequence_length=self._rnn_sequence_length + )[0] + sampled_rnn_batches = self.memory.sample_all( + names=self._rnn_tensors_names, mini_batches=1, sequence_length=self._rnn_sequence_length + )[0] rnn_policy = {} @@ -560,7 +616,12 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor if self.policy is self.value: rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn_batches], "terminated": sampled_dones} else: - rnn_policy = {"rnn": [s.transpose(0, 1) for s, n in zip(sampled_rnn_batches, self._rnn_tensors_names) if "policy" in n], "terminated": sampled_dones} + rnn_policy = { + "rnn": [ + s.transpose(0, 1) for s, n in zip(sampled_rnn_batches, self._rnn_tensors_names) if "policy" in n + ], + "terminated": sampled_dones, + } sampled_states = self._state_preprocessor(sampled_states, train=True) @@ -570,12 +631,14 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor flat_policy_loss_gradient = torch.cat([gradient.view(-1) for gradient in policy_loss_gradient]) # compute the search direction using the conjugate gradient algorithm - search_direction = conjugate_gradient(self.policy, sampled_states, flat_policy_loss_gradient.data, - num_iterations=self._conjugate_gradient_steps) + search_direction = conjugate_gradient( + self.policy, sampled_states, flat_policy_loss_gradient.data, num_iterations=self._conjugate_gradient_steps + ) # compute step size and full step - xHx = (search_direction * fisher_vector_product(self.policy, sampled_states, search_direction, self._damping)) \ - .sum(0, keepdim=True) + xHx = ( + search_direction * fisher_vector_product(self.policy, sampled_states, search_direction, self._damping) + ).sum(0, keepdim=True) step_size = torch.sqrt(2 * self._max_kl_divergence / xHx)[0] full_step = step_size * search_direction @@ -586,7 +649,7 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor expected_improvement = (flat_policy_loss_gradient * full_step).sum(0, keepdim=True) - for alpha in [self._step_fraction * 0.5 ** i for i in range(self._max_backtrack_steps)]: + for alpha in [self._step_fraction * 0.5**i for i in range(self._max_backtrack_steps)]: new_params = params + alpha * full_step vector_to_parameters(new_params, self.policy.parameters()) @@ -605,11 +668,17 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor self.policy.reduce_parameters() # sample mini-batches from memory - sampled_batches = self.memory.sample_all(names=self._tensors_names_value, 
mini_batches=self._mini_batches, sequence_length=self._rnn_sequence_length) + sampled_batches = self.memory.sample_all( + names=self._tensors_names_value, mini_batches=self._mini_batches, sequence_length=self._rnn_sequence_length + ) rnn_value = {} if self._rnn: - sampled_rnn_batches = self.memory.sample_all(names=self._rnn_tensors_names, mini_batches=self._mini_batches, sequence_length=self._rnn_sequence_length) + sampled_rnn_batches = self.memory.sample_all( + names=self._rnn_tensors_names, + mini_batches=self._mini_batches, + sequence_length=self._rnn_sequence_length, + ) cumulative_value_loss = 0 @@ -621,9 +690,19 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor if self._rnn: if self.policy is self.value: - rnn_value = {"rnn": [s.transpose(0, 1) for s in sampled_rnn_batches[i]], "terminated": sampled_dones} + rnn_value = { + "rnn": [s.transpose(0, 1) for s in sampled_rnn_batches[i]], + "terminated": sampled_dones, + } else: - rnn_value = {"rnn": [s.transpose(0, 1) for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) if "value" in n], "terminated": sampled_dones} + rnn_value = { + "rnn": [ + s.transpose(0, 1) + for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) + if "value" in n + ], + "terminated": sampled_dones, + } sampled_states = self._state_preprocessor(sampled_states, train=not epoch) diff --git a/skrl/envs/jax.py b/skrl/envs/jax.py index 2f87c756..fe8e5a24 100644 --- a/skrl/envs/jax.py +++ b/skrl/envs/jax.py @@ -1,6 +1,7 @@ # TODO: Delete this file in future releases from skrl import logger # isort: skip + logger.warning("Using `from skrl.envs.jax import ...` is deprecated and will be removed in future versions.") logger.warning(" - Import loaders using `from skrl.envs.loaders.jax import ...`") logger.warning(" - Import wrappers using `from skrl.envs.wrappers.jax import ...`") @@ -12,6 +13,6 @@ load_isaacgym_env_preview3, load_isaacgym_env_preview4, load_isaaclab_env, - load_omniverse_isaacgym_env + load_omniverse_isaacgym_env, ) from skrl.envs.wrappers.jax import MultiAgentEnvWrapper, Wrapper, wrap_env diff --git a/skrl/envs/loaders/jax/__init__.py b/skrl/envs/loaders/jax/__init__.py index a200ec9f..91c7b8f9 100644 --- a/skrl/envs/loaders/jax/__init__.py +++ b/skrl/envs/loaders/jax/__init__.py @@ -2,7 +2,7 @@ from skrl.envs.loaders.jax.isaacgym_envs import ( load_isaacgym_env_preview2, load_isaacgym_env_preview3, - load_isaacgym_env_preview4 + load_isaacgym_env_preview4, ) from skrl.envs.loaders.jax.isaaclab_envs import load_isaaclab_env from skrl.envs.loaders.jax.omniverse_isaacgym_envs import load_omniverse_isaacgym_env diff --git a/skrl/envs/loaders/torch/__init__.py b/skrl/envs/loaders/torch/__init__.py index a7a97e9a..2bd2c3c4 100644 --- a/skrl/envs/loaders/torch/__init__.py +++ b/skrl/envs/loaders/torch/__init__.py @@ -2,7 +2,7 @@ from skrl.envs.loaders.torch.isaacgym_envs import ( load_isaacgym_env_preview2, load_isaacgym_env_preview3, - load_isaacgym_env_preview4 + load_isaacgym_env_preview4, ) from skrl.envs.loaders.torch.isaaclab_envs import load_isaaclab_env from skrl.envs.loaders.torch.omniverse_isaacgym_envs import load_omniverse_isaacgym_env diff --git a/skrl/envs/loaders/torch/bidexhands_envs.py b/skrl/envs/loaders/torch/bidexhands_envs.py index 71e80dea..40de056e 100644 --- a/skrl/envs/loaders/torch/bidexhands_envs.py +++ b/skrl/envs/loaders/torch/bidexhands_envs.py @@ -26,6 +26,7 @@ def cwd(new_path: str) -> None: finally: os.chdir(current_path) + def _print_cfg(d, indent=0) -> None: """Print the 
environment configuration @@ -41,12 +42,14 @@ def _print_cfg(d, indent=0) -> None: print(" | " * indent + f" |-- {key}: {value}") -def load_bidexhands_env(task_name: str = "", - num_envs: Optional[int] = None, - headless: Optional[bool] = None, - cli_args: Sequence[str] = [], - bidexhands_path: str = "", - show_cfg: bool = True): +def load_bidexhands_env( + task_name: str = "", + num_envs: Optional[int] = None, + headless: Optional[bool] = None, + cli_args: Sequence[str] = [], + bidexhands_path: str = "", + show_cfg: bool = True, +): """Load a Bi-DexHands environment :param task_name: The name of the task (default: ``""``). @@ -88,7 +91,9 @@ def load_bidexhands_env(task_name: str = "", if defined: arg_index = sys.argv.index("--task") + 1 if arg_index >= len(sys.argv): - raise ValueError("No task name defined. Set the task_name parameter or use --task as command line argument") + raise ValueError( + "No task name defined. Set the task_name parameter or use --task as command line argument" + ) if task_name and task_name != sys.argv[arg_index]: logger.warning(f"Overriding task ({task_name}) with command line argument ({sys.argv[arg_index]})") # get task name from function arguments @@ -97,7 +102,9 @@ def load_bidexhands_env(task_name: str = "", sys.argv.append("--task") sys.argv.append(task_name) else: - raise ValueError("No task name defined. Set the task_name parameter or use --task as command line argument") + raise ValueError( + "No task name defined. Set the task_name parameter or use --task as command line argument" + ) # check num_envs from command line arguments defined = False diff --git a/skrl/envs/loaders/torch/isaacgym_envs.py b/skrl/envs/loaders/torch/isaacgym_envs.py index 3d1b0232..30f8288d 100644 --- a/skrl/envs/loaders/torch/isaacgym_envs.py +++ b/skrl/envs/loaders/torch/isaacgym_envs.py @@ -7,9 +7,7 @@ from skrl import logger -__all__ = ["load_isaacgym_env_preview2", - "load_isaacgym_env_preview3", - "load_isaacgym_env_preview4"] +__all__ = ["load_isaacgym_env_preview2", "load_isaacgym_env_preview3", "load_isaacgym_env_preview4"] @contextmanager @@ -28,6 +26,7 @@ def cwd(new_path: str) -> None: finally: os.chdir(current_path) + def _omegaconf_to_dict(config) -> dict: """Convert OmegaConf config to dict @@ -45,6 +44,7 @@ def _omegaconf_to_dict(config) -> dict: d[k] = _omegaconf_to_dict(v) if isinstance(v, DictConfig) else v return d + def _print_cfg(d, indent=0) -> None: """Print the environment configuration @@ -60,12 +60,14 @@ def _print_cfg(d, indent=0) -> None: print(" | " * indent + f" |-- {key}: {value}") -def load_isaacgym_env_preview2(task_name: str = "", - num_envs: Optional[int] = None, - headless: Optional[bool] = None, - cli_args: Sequence[str] = [], - isaacgymenvs_path: str = "", - show_cfg: bool = True): +def load_isaacgym_env_preview2( + task_name: str = "", + num_envs: Optional[int] = None, + headless: Optional[bool] = None, + cli_args: Sequence[str] = [], + isaacgymenvs_path: str = "", + show_cfg: bool = True, +): """Load an Isaac Gym environment (preview 2) :param task_name: The name of the task (default: ``""``). @@ -107,7 +109,9 @@ def load_isaacgym_env_preview2(task_name: str = "", if defined: arg_index = sys.argv.index("--task") + 1 if arg_index >= len(sys.argv): - raise ValueError("No task name defined. Set the task_name parameter or use --task as command line argument") + raise ValueError( + "No task name defined. 
Set the task_name parameter or use --task as command line argument" + ) if task_name and task_name != sys.argv[arg_index]: logger.warning(f"Overriding task ({task_name}) with command line argument ({sys.argv[arg_index]})") # get task name from function arguments @@ -116,7 +120,9 @@ def load_isaacgym_env_preview2(task_name: str = "", sys.argv.append("--task") sys.argv.append(task_name) else: - raise ValueError("No task name defined. Set the task_name parameter or use --task as command line argument") + raise ValueError( + "No task name defined. Set the task_name parameter or use --task as command line argument" + ) # check num_envs from command line arguments defined = False @@ -153,7 +159,9 @@ def load_isaacgym_env_preview2(task_name: str = "", # get isaacgym envs path from isaacgym package metadata if not isaacgymenvs_path: if not hasattr(isaacgym, "__path__"): - raise RuntimeError("isaacgym package is not installed or could not be accessed by the current Python environment") + raise RuntimeError( + "isaacgym package is not installed or could not be accessed by the current Python environment" + ) path = isaacgym.__path__ path = os.path.join(path[0], "..", "rlgpu") else: @@ -170,7 +178,9 @@ def load_isaacgym_env_preview2(task_name: str = "", status = False logger.error(f"Failed to import required packages: {e}") if not status: - raise RuntimeError(f"Path ({path}) is not valid or the isaacgym package is not installed in editable mode (pip install -e .)") + raise RuntimeError( + f"Path ({path}) is not valid or the isaacgym package is not installed in editable mode (pip install -e .)" + ) args = get_args() @@ -191,12 +201,15 @@ def load_isaacgym_env_preview2(task_name: str = "", return env -def load_isaacgym_env_preview3(task_name: str = "", - num_envs: Optional[int] = None, - headless: Optional[bool] = None, - cli_args: Sequence[str] = [], - isaacgymenvs_path: str = "", - show_cfg: bool = True): + +def load_isaacgym_env_preview3( + task_name: str = "", + num_envs: Optional[int] = None, + headless: Optional[bool] = None, + cli_args: Sequence[str] = [], + isaacgymenvs_path: str = "", + show_cfg: bool = True, +): """Load an Isaac Gym environment (preview 3) Isaac Gym benchmark environments: https://github.com/isaac-sim/IsaacGymEnvs @@ -243,14 +256,19 @@ def load_isaacgym_env_preview3(task_name: str = "", # get task name from command line arguments if defined: if task_name and task_name != arg.split("task=")[1].split(" ")[0]: - logger.warning("Overriding task name ({}) with command line argument ({})" \ - .format(task_name, arg.split("task=")[1].split(" ")[0])) + logger.warning( + "Overriding task name ({}) with command line argument ({})".format( + task_name, arg.split("task=")[1].split(" ")[0] + ) + ) # get task name from function arguments else: if task_name: sys.argv.append(f"task={task_name}") else: - raise ValueError("No task name defined. Set task_name parameter or use task= as command line argument") + raise ValueError( + "No task name defined. 
Set task_name parameter or use task= as command line argument" + ) # check num_envs from command line arguments defined = False @@ -261,8 +279,11 @@ def load_isaacgym_env_preview3(task_name: str = "", # get num_envs from command line arguments if defined: if num_envs is not None and num_envs != int(arg.split("num_envs=")[1].split(" ")[0]): - logger.warning("Overriding num_envs ({}) with command line argument (num_envs={})" \ - .format(num_envs, arg.split("num_envs=")[1].split(" ")[0])) + logger.warning( + "Overriding num_envs ({}) with command line argument (num_envs={})".format( + num_envs, arg.split("num_envs=")[1].split(" ")[0] + ) + ) # get num_envs from function arguments elif num_envs is not None and num_envs > 0: sys.argv.append(f"num_envs={num_envs}") @@ -276,8 +297,11 @@ def load_isaacgym_env_preview3(task_name: str = "", # get headless from command line arguments if defined: if headless is not None and str(headless).lower() != arg.split("headless=")[1].split(" ")[0].lower(): - logger.warning("Overriding headless ({}) with command line argument (headless={})" \ - .format(headless, arg.split("headless=")[1].split(" ")[0])) + logger.warning( + "Overriding headless ({}) with command line argument (headless={})".format( + headless, arg.split("headless=")[1].split(" ")[0] + ) + ) # get headless from function arguments elif headless is not None: sys.argv.append(f"headless={headless}") @@ -294,19 +318,19 @@ def load_isaacgym_env_preview3(task_name: str = "", # set omegaconf resolvers try: - OmegaConf.register_new_resolver('eq', lambda x, y: x.lower() == y.lower()) + OmegaConf.register_new_resolver("eq", lambda x, y: x.lower() == y.lower()) except Exception as e: pass try: - OmegaConf.register_new_resolver('contains', lambda x, y: x.lower() in y.lower()) + OmegaConf.register_new_resolver("contains", lambda x, y: x.lower() in y.lower()) except Exception as e: pass try: - OmegaConf.register_new_resolver('if', lambda condition, a, b: a if condition else b) + OmegaConf.register_new_resolver("if", lambda condition, a, b: a if condition else b) except Exception as e: pass try: - OmegaConf.register_new_resolver('resolve_default', lambda default, arg: default if arg == '' else arg) + OmegaConf.register_new_resolver("resolve_default", lambda default, arg: default if arg == "" else arg) except Exception as e: pass @@ -314,7 +338,7 @@ def load_isaacgym_env_preview3(task_name: str = "", config_file = "config" args = get_args_parser().parse_args() search_path = create_automatic_config_search_path(config_file, None, config_path) - hydra_object = Hydra.create_main_hydra2(task_name='load_isaacgymenv', config_search_path=search_path) + hydra_object = Hydra.create_main_hydra2(task_name="load_isaacgymenv", config_search_path=search_path) config = hydra_object.compose_config(config_file, args.overrides, run_mode=RunMode.RUN) cfg = _omegaconf_to_dict(config.task) @@ -327,28 +351,36 @@ def load_isaacgym_env_preview3(task_name: str = "", # load environment sys.path.append(isaacgymenvs_path) from tasks import isaacgym_task_map # type: ignore + try: - env = isaacgym_task_map[config.task.name](cfg=cfg, - sim_device=config.sim_device, - graphics_device_id=config.graphics_device_id, - headless=config.headless) + env = isaacgym_task_map[config.task.name]( + cfg=cfg, + sim_device=config.sim_device, + graphics_device_id=config.graphics_device_id, + headless=config.headless, + ) except TypeError as e: - env = isaacgym_task_map[config.task.name](cfg=cfg, - rl_device=config.rl_device, - sim_device=config.sim_device, - 
graphics_device_id=config.graphics_device_id, - headless=config.headless, - virtual_screen_capture=config.capture_video, # TODO: check - force_render=config.force_render) + env = isaacgym_task_map[config.task.name]( + cfg=cfg, + rl_device=config.rl_device, + sim_device=config.sim_device, + graphics_device_id=config.graphics_device_id, + headless=config.headless, + virtual_screen_capture=config.capture_video, # TODO: check + force_render=config.force_render, + ) return env -def load_isaacgym_env_preview4(task_name: str = "", - num_envs: Optional[int] = None, - headless: Optional[bool] = None, - cli_args: Sequence[str] = [], - isaacgymenvs_path: str = "", - show_cfg: bool = True): + +def load_isaacgym_env_preview4( + task_name: str = "", + num_envs: Optional[int] = None, + headless: Optional[bool] = None, + cli_args: Sequence[str] = [], + isaacgymenvs_path: str = "", + show_cfg: bool = True, +): """Load an Isaac Gym environment (preview 4) Isaac Gym benchmark environments: https://github.com/isaac-sim/IsaacGymEnvs diff --git a/skrl/envs/loaders/torch/isaaclab_envs.py b/skrl/envs/loaders/torch/isaaclab_envs.py index 643ddf26..9c78c150 100644 --- a/skrl/envs/loaders/torch/isaaclab_envs.py +++ b/skrl/envs/loaders/torch/isaaclab_envs.py @@ -23,11 +23,13 @@ def _print_cfg(d, indent=0) -> None: print(" | " * indent + f" |-- {key}: {value}") -def load_isaaclab_env(task_name: str = "", - num_envs: Optional[int] = None, - headless: Optional[bool] = None, - cli_args: Sequence[str] = [], - show_cfg: bool = True): +def load_isaaclab_env( + task_name: str = "", + num_envs: Optional[int] = None, + headless: Optional[bool] = None, + cli_args: Sequence[str] = [], + show_cfg: bool = True, +): """Load an Isaac Lab environment Isaac Lab: https://isaac-sim.github.io/IsaacLab @@ -76,7 +78,9 @@ def load_isaaclab_env(task_name: str = "", if defined: arg_index = sys.argv.index("--task") + 1 if arg_index >= len(sys.argv): - raise ValueError("No task name defined. Set the task_name parameter or use --task as command line argument") + raise ValueError( + "No task name defined. Set the task_name parameter or use --task as command line argument" + ) if task_name and task_name != sys.argv[arg_index]: logger.warning(f"Overriding task ({task_name}) with command line argument ({sys.argv[arg_index]})") # get task name from function arguments @@ -85,7 +89,9 @@ def load_isaaclab_env(task_name: str = "", sys.argv.append("--task") sys.argv.append(task_name) else: - raise ValueError("No task name defined. Set the task_name parameter or use --task as command line argument") + raise ValueError( + "No task name defined. Set the task_name parameter or use --task as command line argument" + ) # check num_envs from command line arguments defined = False @@ -125,8 +131,12 @@ def load_isaaclab_env(task_name: str = "", parser.add_argument("--task", type=str, default=None, help="Name of the task.") parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment") parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.") - parser.add_argument("--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations.") - parser.add_argument("--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes.") + parser.add_argument( + "--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations." 
+ ) + parser.add_argument( + "--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes." + ) # launch the simulation app from omni.isaac.lab.app import AppLauncher diff --git a/skrl/envs/loaders/torch/omniverse_isaacgym_envs.py b/skrl/envs/loaders/torch/omniverse_isaacgym_envs.py index 350c42c4..3d795c95 100644 --- a/skrl/envs/loaders/torch/omniverse_isaacgym_envs.py +++ b/skrl/envs/loaders/torch/omniverse_isaacgym_envs.py @@ -27,6 +27,7 @@ def _omegaconf_to_dict(config) -> dict: d[k] = _omegaconf_to_dict(v) if isinstance(v, DictConfig) else v return d + def _print_cfg(d, indent=0) -> None: """Print the environment configuration @@ -41,14 +42,17 @@ def _print_cfg(d, indent=0) -> None: else: print(" | " * indent + f" |-- {key}: {value}") -def load_omniverse_isaacgym_env(task_name: str = "", - num_envs: Optional[int] = None, - headless: Optional[bool] = None, - cli_args: Sequence[str] = [], - omniisaacgymenvs_path: str = "", - show_cfg: bool = True, - multi_threaded: bool = False, - timeout: int = 30) -> Union["VecEnvBase", "VecEnvMT"]: + +def load_omniverse_isaacgym_env( + task_name: str = "", + num_envs: Optional[int] = None, + headless: Optional[bool] = None, + cli_args: Sequence[str] = [], + omniisaacgymenvs_path: str = "", + show_cfg: bool = True, + multi_threaded: bool = False, + timeout: int = 30, +) -> Union["VecEnvBase", "VecEnvMT"]: """Load an Omniverse Isaac Gym environment (OIGE) Omniverse Isaac Gym benchmark environments: https://github.com/isaac-sim/OmniIsaacGymEnvs @@ -103,14 +107,19 @@ def load_omniverse_isaacgym_env(task_name: str = "", # get task name from command line arguments if defined: if task_name and task_name != arg.split("task=")[1].split(" ")[0]: - logger.warning("Overriding task name ({}) with command line argument (task={})" \ - .format(task_name, arg.split("task=")[1].split(" ")[0])) + logger.warning( + "Overriding task name ({}) with command line argument (task={})".format( + task_name, arg.split("task=")[1].split(" ")[0] + ) + ) # get task name from function arguments else: if task_name: sys.argv.append(f"task={task_name}") else: - raise ValueError("No task name defined. Set task_name parameter or use task= as command line argument") + raise ValueError( + "No task name defined. 
Set task_name parameter or use task= as command line argument" + ) # check num_envs from command line arguments defined = False @@ -121,8 +130,11 @@ def load_omniverse_isaacgym_env(task_name: str = "", # get num_envs from command line arguments if defined: if num_envs is not None and num_envs != int(arg.split("num_envs=")[1].split(" ")[0]): - logger.warning("Overriding num_envs ({}) with command line argument (num_envs={})" \ - .format(num_envs, arg.split("num_envs=")[1].split(" ")[0])) + logger.warning( + "Overriding num_envs ({}) with command line argument (num_envs={})".format( + num_envs, arg.split("num_envs=")[1].split(" ")[0] + ) + ) # get num_envs from function arguments elif num_envs is not None and num_envs > 0: sys.argv.append(f"num_envs={num_envs}") @@ -136,8 +148,11 @@ def load_omniverse_isaacgym_env(task_name: str = "", # get headless from command line arguments if defined: if headless is not None and str(headless).lower() != arg.split("headless=")[1].split(" ")[0].lower(): - logger.warning("Overriding headless ({}) with command line argument (headless={})" \ - .format(headless, arg.split("headless=")[1].split(" ")[0])) + logger.warning( + "Overriding headless ({}) with command line argument (headless={})".format( + headless, arg.split("headless=")[1].split(" ")[0] + ) + ) # get headless from function arguments elif headless is not None: sys.argv.append(f"headless={headless}") @@ -153,16 +168,16 @@ def load_omniverse_isaacgym_env(task_name: str = "", config_path = os.path.join(omniisaacgymenvs_path, "cfg") # set omegaconf resolvers - OmegaConf.register_new_resolver('eq', lambda x, y: x.lower() == y.lower()) - OmegaConf.register_new_resolver('contains', lambda x, y: x.lower() in y.lower()) - OmegaConf.register_new_resolver('if', lambda condition, a, b: a if condition else b) - OmegaConf.register_new_resolver('resolve_default', lambda default, arg: default if arg == '' else arg) + OmegaConf.register_new_resolver("eq", lambda x, y: x.lower() == y.lower()) + OmegaConf.register_new_resolver("contains", lambda x, y: x.lower() in y.lower()) + OmegaConf.register_new_resolver("if", lambda condition, a, b: a if condition else b) + OmegaConf.register_new_resolver("resolve_default", lambda default, arg: default if arg == "" else arg) # get hydra config without use @hydra.main config_file = "config" args = get_args_parser().parse_args() search_path = create_automatic_config_search_path(config_file, None, config_path) - hydra_object = Hydra.create_main_hydra2(task_name='load_omniisaacgymenv', config_search_path=search_path) + hydra_object = Hydra.create_main_hydra2(task_name="load_omniisaacgymenv", config_search_path=search_path) config = hydra_object.compose_config(config_file, args.overrides, run_mode=RunMode.RUN) del config.hydra @@ -177,7 +192,9 @@ def load_omniverse_isaacgym_env(task_name: str = "", # internal classes class _OmniIsaacGymVecEnv(VecEnvBase): def step(self, actions): - actions = torch.clamp(actions, -self._task.clip_actions, self._task.clip_actions).to(self._task.device).clone() + actions = ( + torch.clamp(actions, -self._task.clip_actions, self._task.clip_actions).to(self._task.device).clone() + ) self._task.pre_physics_step(actions) for _ in range(self._task.control_frequency_inv): @@ -186,8 +203,16 @@ def step(self, actions): observations, rewards, dones, info = self._task.post_physics_step() - return {"obs": torch.clamp(observations, -self._task.clip_obs, self._task.clip_obs).to(self._task.rl_device).clone()}, \ - rewards.to(self._task.rl_device).clone(), 
dones.to(self._task.rl_device).clone(), info.copy() + return ( + { + "obs": torch.clamp(observations, -self._task.clip_obs, self._task.clip_obs) + .to(self._task.rl_device) + .clone() + }, + rewards.to(self._task.rl_device).clone(), + dones.to(self._task.rl_device).clone(), + info.copy(), + ) def reset(self): self._task.reset() @@ -212,7 +237,9 @@ def run(self, trainer=None): super().run(_OmniIsaacGymTrainerMT() if trainer is None else trainer) def _parse_data(self, data): - self._observations = torch.clamp(data["obs"], -self._task.clip_obs, self._task.clip_obs).to(self._task.rl_device).clone() + self._observations = ( + torch.clamp(data["obs"], -self._task.clip_obs, self._task.clip_obs).to(self._task.rl_device).clone() + ) self._rewards = data["rew"].to(self._task.rl_device).clone() self._dones = data["reset"].to(self._task.rl_device).clone() self._info = data["extras"].copy() @@ -253,10 +280,12 @@ def close(self): if multi_threaded: try: - env = _OmniIsaacGymVecEnvMT(headless=config.headless, - sim_device=config.device_id, - enable_livestream=config.enable_livestream, - enable_viewport=enable_viewport) + env = _OmniIsaacGymVecEnvMT( + headless=config.headless, + sim_device=config.device_id, + enable_livestream=config.enable_livestream, + enable_viewport=enable_viewport, + ) except (TypeError, omegaconf.errors.ConfigAttributeError): logger.warning("Using an older version of Isaac Sim or OmniIsaacGymEnvs (2022.2.0 or earlier)") env = _OmniIsaacGymVecEnvMT(headless=config.headless) # Isaac Sim 2022.2.0 and earlier @@ -264,10 +293,12 @@ def close(self): env.initialize(env.action_queue, env.data_queue, timeout=timeout) else: try: - env = _OmniIsaacGymVecEnv(headless=config.headless, - sim_device=config.device_id, - enable_livestream=config.enable_livestream, - enable_viewport=enable_viewport) + env = _OmniIsaacGymVecEnv( + headless=config.headless, + sim_device=config.device_id, + enable_livestream=config.enable_livestream, + enable_viewport=enable_viewport, + ) except (TypeError, omegaconf.errors.ConfigAttributeError): logger.warning("Using an older version of Isaac Sim or OmniIsaacGymEnvs (2022.2.0 or earlier)") env = _OmniIsaacGymVecEnv(headless=config.headless) # Isaac Sim 2022.2.0 and earlier diff --git a/skrl/envs/torch.py b/skrl/envs/torch.py index d5922a7f..d9c95eb6 100644 --- a/skrl/envs/torch.py +++ b/skrl/envs/torch.py @@ -1,6 +1,7 @@ # TODO: Delete this file in future releases from skrl import logger # isort: skip + logger.warning("Using `from skrl.envs.torch import ...` is deprecated and will be removed in future versions.") logger.warning(" - Import loaders using `from skrl.envs.loaders.torch import ...`") logger.warning(" - Import wrappers using `from skrl.envs.wrappers.torch import ...`") @@ -12,6 +13,6 @@ load_isaacgym_env_preview3, load_isaacgym_env_preview4, load_isaaclab_env, - load_omniverse_isaacgym_env + load_omniverse_isaacgym_env, ) from skrl.envs.wrappers.torch import MultiAgentEnvWrapper, Wrapper, wrap_env diff --git a/skrl/envs/wrappers/jax/__init__.py b/skrl/envs/wrappers/jax/__init__.py index 48250a98..3fbf4e42 100644 --- a/skrl/envs/wrappers/jax/__init__.py +++ b/skrl/envs/wrappers/jax/__init__.py @@ -75,6 +75,7 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True) -> Union[Wra :return: Wrapped environment :rtype: Wrapper or MultiAgentEnvWrapper """ + def _get_wrapper_name(env, verbose): def _in(values, container): if type(values) == str: @@ -87,7 +88,9 @@ def _in(values, container): base_classes = [str(base).replace("<class '", "").replace("'>", "") for base in 
env.__class__.__bases__] try: - base_classes += [str(base).replace("<class '", "").replace("'>", "") for base in env.unwrapped.__class__.__bases__] + base_classes += [ + str(base).replace("<class '", "").replace("'>", "") for base in env.unwrapped.__class__.__bases__ + ] except: pass base_classes = sorted(list(set(base_classes))) diff --git a/skrl/envs/wrappers/jax/base.py b/skrl/envs/wrappers/jax/base.py index a7a2bc2d..e5ea699e 100644 --- a/skrl/envs/wrappers/jax/base.py +++ b/skrl/envs/wrappers/jax/base.py @@ -27,7 +27,7 @@ def __init__(self, env: Any) -> None: self._device = None if hasattr(self._unwrapped, "device"): if type(self._unwrapped.device) == str: - device_type, device_index = f"{self._unwrapped.device}:0".split(':')[:2] + device_type, device_index = f"{self._unwrapped.device}:0".split(":")[:2] try: self._device = jax.devices(device_type)[int(device_index)] except (RuntimeError, IndexError): @@ -52,7 +52,9 @@ def __getattr__(self, key: str) -> Any: return getattr(self._env, key) if hasattr(self._unwrapped, key): return getattr(self._unwrapped, key) - raise AttributeError(f"Wrapped environment ({self._unwrapped.__class__.__name__}) does not have attribute '{key}'") + raise AttributeError( + f"Wrapped environment ({self._unwrapped.__class__.__name__}) does not have attribute '{key}'" + ) def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: """Reset the environment @@ -64,9 +66,13 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: """ raise NotImplementedError - def step(self, actions: Union[np.ndarray, jax.Array]) -> \ - Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], - Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Any]: + def step(self, actions: Union[np.ndarray, jax.Array]) -> Tuple[ + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Any, + ]: """Perform a step in the environment :param actions: The actions to perform @@ -141,14 +147,12 @@ def state_space(self) -> Union[gymnasium.Space, None]: @property def observation_space(self) -> gymnasium.Space: - """Observation space - """ + """Observation space""" return self._unwrapped.observation_space @property def action_space(self) -> gymnasium.Space: - """Action space - """ + """Action space""" return self._unwrapped.action_space @@ -171,7 +175,7 @@ def __init__(self, env: Any) -> None: self._device = None if hasattr(self._unwrapped, "device"): if type(self._unwrapped.device) == str: - device_type, device_index = f"{self._unwrapped.device}:0".split(':')[:2] + device_type, device_index = f"{self._unwrapped.device}:0".split(":")[:2] try: self._device = jax.devices(device_type)[int(device_index)] except (RuntimeError, IndexError): @@ -196,7 +200,9 @@ def __getattr__(self, key: str) -> Any: return getattr(self._env, key) if hasattr(self._unwrapped, key): return getattr(self._unwrapped, key) - raise AttributeError(f"Wrapped environment ({self._unwrapped.__class__.__name__}) does not have attribute '{key}'") + raise AttributeError( + f"Wrapped environment ({self._unwrapped.__class__.__name__}) does not have attribute '{key}'" + ) def reset(self) -> Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Any]]: """Reset the environment @@ -208,10 +214,13 @@ def reset(self) -> Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str """ raise NotImplementedError - def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \ - Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]], - Mapping[str, 
Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]], - Mapping[str, Any]]: + def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> Tuple[ + Mapping[str, Union[np.ndarray, jax.Array]], + Mapping[str, Union[np.ndarray, jax.Array]], + Mapping[str, Union[np.ndarray, jax.Array]], + Mapping[str, Union[np.ndarray, jax.Array]], + Mapping[str, Any], + ]: """Perform a step in the environment :param actions: The actions to perform @@ -319,14 +328,12 @@ def state_spaces(self) -> Mapping[str, gymnasium.Space]: @property def observation_spaces(self) -> Mapping[str, gymnasium.Space]: - """Observation spaces - """ + """Observation spaces""" return self._unwrapped.observation_spaces @property def action_spaces(self) -> Mapping[str, gymnasium.Space]: - """Action spaces - """ + """Action spaces""" return self._unwrapped.action_spaces def state_space(self, agent: str) -> gymnasium.Space: diff --git a/skrl/envs/wrappers/jax/bidexhands_envs.py b/skrl/envs/wrappers/jax/bidexhands_envs.py index e292047c..b63549a9 100644 --- a/skrl/envs/wrappers/jax/bidexhands_envs.py +++ b/skrl/envs/wrappers/jax/bidexhands_envs.py @@ -21,14 +21,18 @@ # jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: DLPack tensor is on GPU, but no GPU backend was provided. _CPU = jax.devices()[0].device_kind.lower() == "cpu" + def _jax2torch(array, device, from_jax=True): if from_jax: return torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(array)).to(device=device) return torch.tensor(array, device=device) + def _torch2jax(tensor, to_jax=True): if to_jax: - return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(tensor.contiguous().cpu() if _CPU else tensor.contiguous())) + return jax_dlpack.from_dlpack( + torch_dlpack.to_dlpack(tensor.contiguous().cpu() if _CPU else tensor.contiguous()) + ) return tensor.cpu().numpy() @@ -70,24 +74,27 @@ def state_spaces(self) -> Mapping[str, gymnasium.Space]: this property returns a dictionary (for consistency with the other space-related properties) with the same space for all the agents """ - return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.share_observation_space)} + return { + uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.share_observation_space) + } @property def observation_spaces(self) -> Mapping[str, gymnasium.Space]: - """Observation spaces - """ + """Observation spaces""" return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.observation_space)} @property def action_spaces(self) -> Mapping[str, gymnasium.Space]: - """Action spaces - """ + """Action spaces""" return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.action_space)} - def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \ - Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]], - Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]], - Mapping[str, Any]]: + def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> Tuple[ + Mapping[str, Union[np.ndarray, jax.Array]], + Mapping[str, Union[np.ndarray, jax.Array]], + Mapping[str, Union[np.ndarray, jax.Array]], + Mapping[str, Union[np.ndarray, jax.Array]], + Mapping[str, Any], + ]: """Perform a step in the environment :param actions: The actions to perform @@ -107,9 +114,9 @@ def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \ terminated = 
_torch2jax(terminated.to(dtype=torch.int8), self._jax) self._states = states[:, 0] - self._observations = {uid: observations[:,i] for i, uid in enumerate(self.possible_agents)} - rewards = {uid: rewards[:,i].reshape(-1, 1) for i, uid in enumerate(self.possible_agents)} - terminated = {uid: terminated[:,i].reshape(-1, 1) for i, uid in enumerate(self.possible_agents)} + self._observations = {uid: observations[:, i] for i, uid in enumerate(self.possible_agents)} + rewards = {uid: rewards[:, i].reshape(-1, 1) for i, uid in enumerate(self.possible_agents)} + terminated = {uid: terminated[:, i].reshape(-1, 1) for i, uid in enumerate(self.possible_agents)} truncated = terminated return self._observations, rewards, terminated, truncated, self._info @@ -135,16 +142,14 @@ def reset(self) -> Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str states = _torch2jax(states, self._jax) self._states = states[:, 0] - self._observations = {uid: observations[:,i] for i, uid in enumerate(self.possible_agents)} + self._observations = {uid: observations[:, i] for i, uid in enumerate(self.possible_agents)} self._reset_once = False return self._observations, self._info def render(self, *args, **kwargs) -> None: - """Render the environment - """ + """Render the environment""" return None def close(self) -> None: - """Close the environment - """ + """Close the environment""" pass diff --git a/skrl/envs/wrappers/jax/brax_envs.py b/skrl/envs/wrappers/jax/brax_envs.py index 8fa10a29..68054cee 100644 --- a/skrl/envs/wrappers/jax/brax_envs.py +++ b/skrl/envs/wrappers/jax/brax_envs.py @@ -12,7 +12,7 @@ convert_gym_space, flatten_tensorized_space, tensorize_space, - unflatten_tensorized_space + unflatten_tensorized_space, ) @@ -26,25 +26,28 @@ def __init__(self, env: Any) -> None: super().__init__(env) import brax.envs.wrappers.gym + env = brax.envs.wrappers.gym.VectorGymWrapper(env) self._env = env self._unwrapped = env.unwrapped @property def observation_space(self) -> gymnasium.Space: - """Observation space - """ + """Observation space""" return convert_gym_space(self._unwrapped.observation_space, squeeze_batch_dimension=True) @property def action_space(self) -> gymnasium.Space: - """Action space - """ + """Action space""" return convert_gym_space(self._unwrapped.action_space, squeeze_batch_dimension=True) - def step(self, actions: Union[np.ndarray, jax.Array]) -> \ - Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], - Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Any]: + def step(self, actions: Union[np.ndarray, jax.Array]) -> Tuple[ + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Any, + ]: """Perform a step in the environment :param actions: The actions to perform @@ -76,13 +79,13 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: return observation, {} def render(self, *args, **kwargs) -> None: - """Render the environment - """ + """Render the environment""" frame = self._env.render(mode="rgb_array") # render the frame using OpenCV try: import cv2 + cv2.imshow("env", cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) cv2.waitKey(1) except ImportError as e: @@ -90,7 +93,6 @@ def render(self, *args, **kwargs) -> None: return frame def close(self) -> None: - """Close the environment - """ + """Close the environment""" # self._env.close() raises AttributeError: 'VectorGymWrapper' object has no attribute 'closed' pass diff --git a/skrl/envs/wrappers/jax/gym_envs.py 
b/skrl/envs/wrappers/jax/gym_envs.py index 412a6435..3fd6e56b 100644 --- a/skrl/envs/wrappers/jax/gym_envs.py +++ b/skrl/envs/wrappers/jax/gym_envs.py @@ -13,7 +13,7 @@ flatten_tensorized_space, tensorize_space, unflatten_tensorized_space, - untensorize_space + untensorize_space, ) @@ -33,6 +33,7 @@ def __init__(self, env: Any) -> None: np.bool8 = np.bool import gym + self._vectorized = False try: if isinstance(env, gym.vector.VectorEnv): @@ -49,23 +50,25 @@ def __init__(self, env: Any) -> None: @property def observation_space(self) -> gymnasium.Space: - """Observation space - """ + """Observation space""" if self._vectorized: return convert_gym_space(self._env.single_observation_space) return convert_gym_space(self._env.observation_space) @property def action_space(self) -> gymnasium.Space: - """Action space - """ + """Action space""" if self._vectorized: return convert_gym_space(self._env.single_action_space) return convert_gym_space(self._env.action_space) - def step(self, actions: Union[np.ndarray, jax.Array]) -> \ - Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], - Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Any]: + def step(self, actions: Union[np.ndarray, jax.Array]) -> Tuple[ + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Any, + ]: """Perform a step in the environment :param actions: The actions to perform @@ -76,9 +79,11 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \ """ if self._jax or isinstance(actions, jax.Array): actions = np.asarray(jax.device_get(actions)) - actions = untensorize_space(self.action_space, - unflatten_tensorized_space(self.action_space, actions), - squeeze_batch_dimension=not self._vectorized) + actions = untensorize_space( + self.action_space, + unflatten_tensorized_space(self.action_space, actions), + squeeze_batch_dimension=not self._vectorized, + ) if self._deprecated_api: observation, reward, terminated, info = self._env.step(actions) @@ -94,7 +99,9 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \ observation, reward, terminated, truncated, info = self._env.step(actions) # convert response to numpy or jax - observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device, False), False) + observation = flatten_tensorized_space( + tensorize_space(self.observation_space, observation, self.device, False), False + ) reward = np.array(reward, dtype=np.float32).reshape(self.num_envs, -1) terminated = np.array(terminated, dtype=np.int8).reshape(self.num_envs, -1) truncated = np.array(truncated, dtype=np.int8).reshape(self.num_envs, -1) @@ -125,7 +132,9 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: self._info = {} else: observation, self._info = self._env.reset() - self._observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device, False), False) + self._observation = flatten_tensorized_space( + tensorize_space(self.observation_space, observation, self.device, False), False + ) if self._jax: self._observation = jax.device_put(self._observation, device=self.device) self._reset_once = False @@ -138,19 +147,19 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: observation, info = self._env.reset() # convert response to numpy or jax - observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device, False), False) + observation = flatten_tensorized_space( + 
tensorize_space(self.observation_space, observation, self.device, False), False + ) if self._jax: observation = jax.device_put(observation, device=self.device) return observation, info def render(self, *args, **kwargs) -> Any: - """Render the environment - """ + """Render the environment""" if self._vectorized: return None return self._env.render(*args, **kwargs) def close(self) -> None: - """Close the environment - """ + """Close the environment""" self._env.close() diff --git a/skrl/envs/wrappers/jax/gymnasium_envs.py b/skrl/envs/wrappers/jax/gymnasium_envs.py index b8edd4fe..9f836fc3 100644 --- a/skrl/envs/wrappers/jax/gymnasium_envs.py +++ b/skrl/envs/wrappers/jax/gymnasium_envs.py @@ -11,7 +11,7 @@ flatten_tensorized_space, tensorize_space, unflatten_tensorized_space, - untensorize_space + untensorize_space, ) @@ -36,23 +36,25 @@ def __init__(self, env: Any) -> None: @property def observation_space(self) -> gymnasium.Space: - """Observation space - """ + """Observation space""" if self._vectorized: return self._env.single_observation_space return self._env.observation_space @property def action_space(self) -> gymnasium.Space: - """Action space - """ + """Action space""" if self._vectorized: return self._env.single_action_space return self._env.action_space - def step(self, actions: Union[np.ndarray, jax.Array]) -> \ - Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], - Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Any]: + def step(self, actions: Union[np.ndarray, jax.Array]) -> Tuple[ + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Any, + ]: """Perform a step in the environment :param actions: The actions to perform @@ -63,14 +65,18 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \ """ if self._jax or isinstance(actions, jax.Array): actions = np.asarray(jax.device_get(actions)) - actions = untensorize_space(self.action_space, - unflatten_tensorized_space(self.action_space, actions), - squeeze_batch_dimension=not self._vectorized) + actions = untensorize_space( + self.action_space, + unflatten_tensorized_space(self.action_space, actions), + squeeze_batch_dimension=not self._vectorized, + ) observation, reward, terminated, truncated, info = self._env.step(actions) # convert response to numpy or jax - observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device, False), False) + observation = flatten_tensorized_space( + tensorize_space(self.observation_space, observation, self.device, False), False + ) reward = np.array(reward, dtype=np.float32).reshape(self.num_envs, -1) terminated = np.array(terminated, dtype=np.int8).reshape(self.num_envs, -1) truncated = np.array(truncated, dtype=np.int8).reshape(self.num_envs, -1) @@ -97,7 +103,9 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: if self._vectorized: if self._reset_once: observation, self._info = self._env.reset() - self._observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device, False), False) + self._observation = flatten_tensorized_space( + tensorize_space(self.observation_space, observation, self.device, False), False + ) if self._jax: self._observation = jax.device_put(self._observation, device=self.device) self._reset_once = False @@ -106,19 +114,19 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: observation, info = self._env.reset() # convert response to numpy or jax - observation = 
flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device, False), False) + observation = flatten_tensorized_space( + tensorize_space(self.observation_space, observation, self.device, False), False + ) if self._jax: observation = jax.device_put(observation, device=self.device) return observation, info def render(self, *args, **kwargs) -> Any: - """Render the environment - """ + """Render the environment""" if self._vectorized: return self._env.call("render", *args, **kwargs) return self._env.render(*args, **kwargs) def close(self) -> None: - """Close the environment - """ + """Close the environment""" self._env.close() diff --git a/skrl/envs/wrappers/jax/isaacgym_envs.py b/skrl/envs/wrappers/jax/isaacgym_envs.py index 5136de8a..43df3a23 100644 --- a/skrl/envs/wrappers/jax/isaacgym_envs.py +++ b/skrl/envs/wrappers/jax/isaacgym_envs.py @@ -17,7 +17,7 @@ convert_gym_space, flatten_tensorized_space, tensorize_space, - unflatten_tensorized_space + unflatten_tensorized_space, ) from skrl import logger @@ -30,14 +30,18 @@ if _CPU: logger.warning("IsaacGymEnvs runs on GPU, but there is no GPU backend for JAX. JAX operations will run on CPU.") + def _jax2torch(array, device, from_jax=True): if from_jax: return torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(array)).to(device=device) return torch.tensor(array, device=device) + def _torch2jax(tensor, to_jax=True): if to_jax: - return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(tensor.contiguous().cpu() if _CPU else tensor.contiguous())) + return jax_dlpack.from_dlpack( + torch_dlpack.to_dlpack(tensor.contiguous().cpu() if _CPU else tensor.contiguous()) + ) return tensor.cpu().numpy() @@ -56,19 +60,21 @@ def __init__(self, env: Any) -> None: @property def observation_space(self) -> gymnasium.Space: - """Observation space - """ + """Observation space""" return convert_gym_space(self._unwrapped.observation_space) @property def action_space(self) -> gymnasium.Space: - """Action space - """ + """Action space""" return convert_gym_space(self._unwrapped.action_space) - def step(self, actions: Union[np.ndarray, jax.Array]) -> \ - Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], - Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Any]: + def step(self, actions: Union[np.ndarray, jax.Array]) -> Tuple[ + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Any, + ]: """Perform a step in the environment :param actions: The actions to perform @@ -80,18 +86,24 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \ actions = _jax2torch(actions, self._env.device, self._jax) with torch.no_grad(): - observations, reward, terminated, self._info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) + observations, reward, terminated, self._info = self._env.step( + unflatten_tensorized_space(self.action_space, actions) + ) observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations)) terminated = terminated.to(dtype=torch.int8) - truncated = self._info["time_outs"].to(dtype=torch.int8) if "time_outs" in self._info else torch.zeros_like(terminated) + truncated = ( + self._info["time_outs"].to(dtype=torch.int8) if "time_outs" in self._info else torch.zeros_like(terminated) + ) self._observations = _torch2jax(observations, self._jax) - return self._observations, \ - _torch2jax(reward.view(-1, 1), self._jax), \ - _torch2jax(terminated.view(-1, 1), self._jax), \ - 
_torch2jax(truncated.view(-1, 1), self._jax), \ - self._info + return ( + self._observations, + _torch2jax(reward.view(-1, 1), self._jax), + _torch2jax(terminated.view(-1, 1), self._jax), + _torch2jax(truncated.view(-1, 1), self._jax), + self._info, + ) def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: """Reset the environment @@ -107,13 +119,11 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: return self._observations, self._info def render(self, *args, **kwargs) -> None: - """Render the environment - """ + """Render the environment""" return None def close(self) -> None: - """Close the environment - """ + """Close the environment""" pass @@ -132,20 +142,17 @@ def __init__(self, env: Any) -> None: @property def observation_space(self) -> gymnasium.Space: - """Observation space - """ + """Observation space""" return convert_gym_space(self._unwrapped.observation_space) @property def action_space(self) -> gymnasium.Space: - """Action space - """ + """Action space""" return convert_gym_space(self._unwrapped.action_space) @property def state_space(self) -> Union[gymnasium.Space, None]: - """State space - """ + """State space""" try: if self.num_states: return convert_gym_space(self._unwrapped.state_space) @@ -153,9 +160,13 @@ def state_space(self) -> Union[gymnasium.Space, None]: pass return None - def step(self, actions: Union[np.ndarray, jax.Array]) ->\ - Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], - Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Any]: + def step(self, actions: Union[np.ndarray, jax.Array]) -> Tuple[ + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Any, + ]: """Perform a step in the environment :param actions: The actions to perform @@ -167,18 +178,24 @@ def step(self, actions: Union[np.ndarray, jax.Array]) ->\ actions = _jax2torch(actions, self._env.device, self._jax) with torch.no_grad(): - observations, reward, terminated, self._info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) + observations, reward, terminated, self._info = self._env.step( + unflatten_tensorized_space(self.action_space, actions) + ) observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) terminated = terminated.to(dtype=torch.int8) - truncated = self._info["time_outs"].to(dtype=torch.int8) if "time_outs" in self._info else torch.zeros_like(terminated) + truncated = ( + self._info["time_outs"].to(dtype=torch.int8) if "time_outs" in self._info else torch.zeros_like(terminated) + ) self._observations = _torch2jax(observations, self._jax) - return self._observations, \ - _torch2jax(reward.view(-1, 1), self._jax), \ - _torch2jax(terminated.view(-1, 1), self._jax), \ - _torch2jax(truncated.view(-1, 1), self._jax), \ - self._info + return ( + self._observations, + _torch2jax(reward.view(-1, 1), self._jax), + _torch2jax(terminated.view(-1, 1), self._jax), + _torch2jax(truncated.view(-1, 1), self._jax), + self._info, + ) def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: """Reset the environment @@ -194,11 +211,9 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: return self._observations, self._info def render(self, *args, **kwargs) -> None: - """Render the environment - """ + """Render the environment""" return None def close(self) -> None: - """Close the environment - """ + """Close the environment""" pass diff --git a/skrl/envs/wrappers/jax/isaaclab_envs.py 
b/skrl/envs/wrappers/jax/isaaclab_envs.py index 6ce151e7..39584516 100644 --- a/skrl/envs/wrappers/jax/isaaclab_envs.py +++ b/skrl/envs/wrappers/jax/isaaclab_envs.py @@ -25,14 +25,18 @@ if _CPU: logger.warning("Isaac Lab runs on GPU, but there is no GPU backend for JAX. JAX operations will run on CPU.") + def _jax2torch(array, device, from_jax=True): if from_jax: return torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(array)).to(device=device) return torch.tensor(array, device=device) + def _torch2jax(tensor, to_jax=True): if to_jax: - return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(tensor.contiguous().cpu() if _CPU else tensor.contiguous())) + return jax_dlpack.from_dlpack( + torch_dlpack.to_dlpack(tensor.contiguous().cpu() if _CPU else tensor.contiguous()) + ) return tensor.cpu().numpy() @@ -52,8 +56,7 @@ def __init__(self, env: Any) -> None: @property def state_space(self) -> Union[gymnasium.Space, None]: - """State space - """ + """State space""" try: return self._unwrapped.single_observation_space["critic"] except KeyError: @@ -65,8 +68,7 @@ def state_space(self) -> Union[gymnasium.Space, None]: @property def observation_space(self) -> gymnasium.Space: - """Observation space - """ + """Observation space""" try: return self._unwrapped.single_observation_space["policy"] except: @@ -74,16 +76,19 @@ def observation_space(self) -> gymnasium.Space: @property def action_space(self) -> gymnasium.Space: - """Action space - """ + """Action space""" try: return self._unwrapped.single_action_space except: return self._unwrapped.action_space - def step(self, actions: Union[np.ndarray, jax.Array]) -> \ - Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], - Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Any]: + def step(self, actions: Union[np.ndarray, jax.Array]) -> Tuple[ + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Any, + ]: """Perform a step in the environment :param actions: The actions to perform @@ -103,11 +108,13 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \ truncated = truncated.to(dtype=torch.int8) self._observations = _torch2jax(observations, self._jax) - return self._observations, \ - _torch2jax(reward.view(-1, 1), self._jax), \ - _torch2jax(terminated.view(-1, 1), self._jax), \ - _torch2jax(truncated.view(-1, 1), self._jax), \ - self._info + return ( + self._observations, + _torch2jax(reward.view(-1, 1), self._jax), + _torch2jax(terminated.view(-1, 1), self._jax), + _torch2jax(truncated.view(-1, 1), self._jax), + self._info, + ) def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: """Reset the environment @@ -123,13 +130,11 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: return self._observations, self._info def render(self, *args, **kwargs) -> None: - """Render the environment - """ + """Render the environment""" return None def close(self) -> None: - """Close the environment - """ + """Close the environment""" self._env.close() @@ -147,9 +152,13 @@ def __init__(self, env: Any) -> None: self._observations = None self._info = {} - def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \ - Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]], - Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Any]]: + def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> Tuple[ + Mapping[str, Union[np.ndarray, jax.Array]], + 
Mapping[str, Union[np.ndarray, jax.Array]], + Mapping[str, Union[np.ndarray, jax.Array]], + Mapping[str, Union[np.ndarray, jax.Array]], + Mapping[str, Any], + ]: """Perform a step in the environment :param actions: The actions to perform @@ -163,14 +172,18 @@ def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \ with torch.no_grad(): observations, rewards, terminated, truncated, self._info = self._env.step(actions) - observations = {k: flatten_tensorized_space(tensorize_space(self.observation_spaces[k], v)) for k, v in observations.items()} + observations = { + k: flatten_tensorized_space(tensorize_space(self.observation_spaces[k], v)) for k, v in observations.items() + } self._observations = {uid: _torch2jax(value, self._jax) for uid, value in observations.items()} - return self._observations, \ - {uid: _torch2jax(value.view(-1, 1), self._jax) for uid, value in rewards.items()}, \ - {uid: _torch2jax(value.to(dtype=torch.int8).view(-1, 1), self._jax) for uid, value in terminated.items()}, \ - {uid: _torch2jax(value.to(dtype=torch.int8).view(-1, 1), self._jax) for uid, value in truncated.items()}, \ - self._info + return ( + self._observations, + {uid: _torch2jax(value.view(-1, 1), self._jax) for uid, value in rewards.items()}, + {uid: _torch2jax(value.to(dtype=torch.int8).view(-1, 1), self._jax) for uid, value in terminated.items()}, + {uid: _torch2jax(value.to(dtype=torch.int8).view(-1, 1), self._jax) for uid, value in truncated.items()}, + self._info, + ) def reset(self) -> Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Any]]: """Reset the environment @@ -180,7 +193,10 @@ def reset(self) -> Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str """ if self._reset_once: observations, self._info = self._env.reset() - observations = {k: flatten_tensorized_space(tensorize_space(self.observation_spaces[k], v)) for k, v in observations.items()} + observations = { + k: flatten_tensorized_space(tensorize_space(self.observation_spaces[k], v)) + for k, v in observations.items() + } self._observations = {uid: _torch2jax(value, self._jax) for uid, value in observations.items()} self._reset_once = False return self._observations, self._info @@ -198,11 +214,9 @@ def state(self) -> Union[np.ndarray, jax.Array, None]: return state def render(self, *args, **kwargs) -> None: - """Render the environment - """ + """Render the environment""" return None def close(self) -> None: - """Close the environment - """ + """Close the environment""" self._env.close() diff --git a/skrl/envs/wrappers/jax/omniverse_isaacgym_envs.py b/skrl/envs/wrappers/jax/omniverse_isaacgym_envs.py index 1ecde802..9a1f2865 100644 --- a/skrl/envs/wrappers/jax/omniverse_isaacgym_envs.py +++ b/skrl/envs/wrappers/jax/omniverse_isaacgym_envs.py @@ -23,14 +23,18 @@ if _CPU: logger.warning("OmniIsaacGymEnvs runs on GPU, but there is no GPU backend for JAX. 
JAX operations will run on CPU.") + def _jax2torch(array, device, from_jax=True): if from_jax: return torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(array)).to(device=device) return torch.tensor(array, device=device) + def _torch2jax(tensor, to_jax=True): if to_jax: - return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(tensor.contiguous().cpu() if _CPU else tensor.contiguous())) + return jax_dlpack.from_dlpack( + torch_dlpack.to_dlpack(tensor.contiguous().cpu() if _CPU else tensor.contiguous()) + ) return tensor.cpu().numpy() @@ -58,9 +62,13 @@ def run(self, trainer: Optional["omni.isaac.gym.vec_env.vec_env_mt.TrainerMT"] = """ self._env.run(trainer) - def step(self, actions: Union[np.ndarray, jax.Array]) -> \ - Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], - Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Any]: + def step(self, actions: Union[np.ndarray, jax.Array]) -> Tuple[ + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Any, + ]: """Perform a step in the environment :param actions: The actions to perform @@ -72,18 +80,24 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \ actions = _jax2torch(actions, self._env_device, self._jax) with torch.no_grad(): - observations, reward, terminated, self._info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) + observations, reward, terminated, self._info = self._env.step( + unflatten_tensorized_space(self.action_space, actions) + ) observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) terminated = terminated.to(dtype=torch.int8) - truncated = self._info["time_outs"].to(dtype=torch.int8) if "time_outs" in self._info else torch.zeros_like(terminated) + truncated = ( + self._info["time_outs"].to(dtype=torch.int8) if "time_outs" in self._info else torch.zeros_like(terminated) + ) self._observations = _torch2jax(observations, self._jax) - return self._observations, \ - _torch2jax(reward.view(-1, 1), self._jax), \ - _torch2jax(terminated.view(-1, 1), self._jax), \ - _torch2jax(truncated.view(-1, 1), self._jax), \ - self._info + return ( + self._observations, + _torch2jax(reward.view(-1, 1), self._jax), + _torch2jax(terminated.view(-1, 1), self._jax), + _torch2jax(truncated.view(-1, 1), self._jax), + self._info, + ) def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: """Reset the environment @@ -99,11 +113,9 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: return self._observations, self._info def render(self, *args, **kwargs) -> None: - """Render the environment - """ + """Render the environment""" return None def close(self) -> None: - """Close the environment - """ + """Close the environment""" self._env.close() diff --git a/skrl/envs/wrappers/jax/pettingzoo_envs.py b/skrl/envs/wrappers/jax/pettingzoo_envs.py index 180e0209..5381719a 100644 --- a/skrl/envs/wrappers/jax/pettingzoo_envs.py +++ b/skrl/envs/wrappers/jax/pettingzoo_envs.py @@ -10,7 +10,7 @@ flatten_tensorized_space, tensorize_space, unflatten_tensorized_space, - untensorize_space + untensorize_space, ) @@ -23,10 +23,13 @@ def __init__(self, env: Any) -> None: """ super().__init__(env) - def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \ - Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]], - Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Union[np.ndarray, jax.Array]], - Mapping[str, Any]]: + def 
step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> Tuple[ + Mapping[str, Union[np.ndarray, jax.Array]], + Mapping[str, Union[np.ndarray, jax.Array]], + Mapping[str, Union[np.ndarray, jax.Array]], + Mapping[str, Union[np.ndarray, jax.Array]], + Mapping[str, Any], + ]: """Perform a step in the environment :param actions: The actions to perform @@ -37,13 +40,23 @@ def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> \ """ if self._jax: actions = jax.device_get(actions) - actions = {uid: untensorize_space(self.action_spaces[uid], unflatten_tensorized_space(self.action_spaces[uid], action)) for uid, action in actions.items()} + actions = { + uid: untensorize_space(self.action_spaces[uid], unflatten_tensorized_space(self.action_spaces[uid], action)) + for uid, action in actions.items() + } observations, rewards, terminated, truncated, infos = self._env.step(actions) # convert response to numpy or jax - observations = {uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, self.device, False), False) for uid, value in observations.items()} + observations = { + uid: flatten_tensorized_space( + tensorize_space(self.observation_spaces[uid], value, self.device, False), False + ) + for uid, value in observations.items() + } rewards = {uid: np.array(value, dtype=np.float32).reshape(self.num_envs, -1) for uid, value in rewards.items()} - terminated = {uid: np.array(value, dtype=np.int8).reshape(self.num_envs, -1) for uid, value in terminated.items()} + terminated = { + uid: np.array(value, dtype=np.int8).reshape(self.num_envs, -1) for uid, value in terminated.items() + } truncated = {uid: np.array(value, dtype=np.int8).reshape(self.num_envs, -1) for uid, value in truncated.items()} if self._jax: observations = {uid: jax.device_put(value, device=self.device) for uid, value in observations.items()} @@ -58,7 +71,9 @@ def state(self) -> Union[np.ndarray, jax.Array]: :return: State :rtype: np.ndarray or jax.Array """ - state = flatten_tensorized_space(tensorize_space(next(iter(self.state_spaces.values())), self._env.state(), self.device, False), False) + state = flatten_tensorized_space( + tensorize_space(next(iter(self.state_spaces.values())), self._env.state(), self.device, False), False + ) if self._jax: state = jax.device_put(state, device=self.device) return state @@ -77,17 +92,20 @@ def reset(self) -> Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str observations, infos = outputs # convert response to numpy or jax - observations = {uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, self.device, False), False) for uid, value in observations.items()} + observations = { + uid: flatten_tensorized_space( + tensorize_space(self.observation_spaces[uid], value, self.device, False), False + ) + for uid, value in observations.items() + } if self._jax: observations = {uid: jax.device_put(value, device=self.device) for uid, value in observations.items()} return observations, infos def render(self, *args, **kwargs) -> Any: - """Render the environment - """ + """Render the environment""" return self._env.render(*args, **kwargs) def close(self) -> None: - """Close the environment - """ + """Close the environment""" self._env.close() diff --git a/skrl/envs/wrappers/torch/__init__.py b/skrl/envs/wrappers/torch/__init__.py index 4983ddf2..5fcd6560 100644 --- a/skrl/envs/wrappers/torch/__init__.py +++ b/skrl/envs/wrappers/torch/__init__.py @@ -81,6 +81,7 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool 
= True) -> Union[Wra
:return: Wrapped environment
:rtype: Wrapper or MultiAgentEnvWrapper
"""
+
def _get_wrapper_name(env, verbose):
def _in(values, container):
if type(values) == str:
@@ -93,7 +94,9 @@ def _in(values, container):
base_classes = [str(base).replace("<class '", "").replace("'>", "") for base in env.__class__.__bases__]
try:
- base_classes += [str(base).replace("<class '", "").replace("'>", "") for base in env.unwrapped.__class__.__bases__]
+ base_classes += [
+ str(base).replace("<class '", "").replace("'>", "") for base in env.unwrapped.__class__.__bases__
+ ]
except:
pass
base_classes = sorted(list(set(base_classes)))
diff --git a/skrl/envs/wrappers/torch/base.py b/skrl/envs/wrappers/torch/base.py
index 7b6893e7..1483c559 100644
--- a/skrl/envs/wrappers/torch/base.py
+++ b/skrl/envs/wrappers/torch/base.py
@@ -39,7 +39,9 @@ def __getattr__(self, key: str) -> Any:
return getattr(self._env, key)
if hasattr(self._unwrapped, key):
return getattr(self._unwrapped, key)
- raise AttributeError(f"Wrapped environment ({self._unwrapped.__class__.__name__}) does not have attribute '{key}'")
+ raise AttributeError(
+ f"Wrapped environment ({self._unwrapped.__class__.__name__}) does not have attribute '{key}'"
+ )
def reset(self) -> Tuple[torch.Tensor, Any]:
"""Reset the environment
@@ -126,14 +128,12 @@ def state_space(self) -> Union[gymnasium.Space, None]:
@property
def observation_space(self) -> gymnasium.Space:
- """Observation space
- """
+ """Observation space"""
return self._unwrapped.observation_space
@property
def action_space(self) -> gymnasium.Space:
- """Action space
- """
+ """Action space"""
return self._unwrapped.action_space
@@ -171,7 +171,9 @@ def __getattr__(self, key: str) -> Any:
return getattr(self._env, key)
if hasattr(self._unwrapped, key):
return getattr(self._unwrapped, key)
- raise AttributeError(f"Wrapped environment ({self._unwrapped.__class__.__name__}) does not have attribute '{key}'")
+ raise AttributeError(
+ f"Wrapped environment ({self._unwrapped.__class__.__name__}) does not have attribute '{key}'"
+ )
def reset(self) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, Any]]:
"""Reset the environment
@@ -183,9 +185,13 @@ def reset(self) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, Any]]:
"""
raise NotImplementedError
- def step(self, actions: Mapping[str, torch.Tensor]) -> \
- Tuple[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor],
- Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], Mapping[str, Any]]:
+ def step(self, actions: Mapping[str, torch.Tensor]) -> Tuple[
+ Mapping[str, torch.Tensor],
+ Mapping[str, torch.Tensor],
+ Mapping[str, torch.Tensor],
+ Mapping[str, torch.Tensor],
+ Mapping[str, Any],
+ ]:
"""Perform a step in the environment
:param actions: The actions to perform
@@ -293,14 +299,12 @@ def state_spaces(self) -> Mapping[str, gymnasium.Space]:
@property
def observation_spaces(self) -> Mapping[str, gymnasium.Space]:
- """Observation spaces
- """
+ """Observation spaces"""
return self._unwrapped.observation_spaces
@property
def action_spaces(self) -> Mapping[str, gymnasium.Space]:
- """Action spaces
- """
+ """Action spaces"""
return self._unwrapped.action_spaces
def state_space(self, agent: str) -> gymnasium.Space:
diff --git a/skrl/envs/wrappers/torch/bidexhands_envs.py b/skrl/envs/wrappers/torch/bidexhands_envs.py
index 827c181f..cdaf953d 100644
--- a/skrl/envs/wrappers/torch/bidexhands_envs.py
+++ b/skrl/envs/wrappers/torch/bidexhands_envs.py
@@ -46,23 +46,27 @@ def state_spaces(self) -> Mapping[str, gymnasium.Space]:
this property returns a dictionary (for consistency with the other
space-related properties) with the same space for all the agents """ - return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.share_observation_space)} + return { + uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.share_observation_space) + } @property def observation_spaces(self) -> Mapping[str, gymnasium.Space]: - """Observation spaces - """ + """Observation spaces""" return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.observation_space)} @property def action_spaces(self) -> Mapping[str, gymnasium.Space]: - """Action spaces - """ + """Action spaces""" return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.action_space)} - def step(self, actions: Mapping[str, torch.Tensor]) -> \ - Tuple[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], - Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], Mapping[str, Any]]: + def step(self, actions: Mapping[str, torch.Tensor]) -> Tuple[ + Mapping[str, torch.Tensor], + Mapping[str, torch.Tensor], + Mapping[str, torch.Tensor], + Mapping[str, torch.Tensor], + Mapping[str, Any], + ]: """Perform a step in the environment :param actions: The actions to perform @@ -75,9 +79,9 @@ def step(self, actions: Mapping[str, torch.Tensor]) -> \ observations, states, rewards, terminated, _, _ = self._env.step(actions) self._states = states[:, 0] - self._observations = {uid: observations[:,i] for i, uid in enumerate(self.possible_agents)} - rewards = {uid: rewards[:,i].view(-1, 1) for i, uid in enumerate(self.possible_agents)} - terminated = {uid: terminated[:,i].view(-1, 1) for i, uid in enumerate(self.possible_agents)} + self._observations = {uid: observations[:, i] for i, uid in enumerate(self.possible_agents)} + rewards = {uid: rewards[:, i].view(-1, 1) for i, uid in enumerate(self.possible_agents)} + terminated = {uid: terminated[:, i].view(-1, 1) for i, uid in enumerate(self.possible_agents)} truncated = {uid: torch.zeros_like(value) for uid, value in terminated.items()} return self._observations, rewards, terminated, truncated, self._info @@ -99,16 +103,14 @@ def reset(self) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, Any]]: if self._reset_once: observations, states, _ = self._env.reset() self._states = states[:, 0] - self._observations = {uid: observations[:,i] for i, uid in enumerate(self.possible_agents)} + self._observations = {uid: observations[:, i] for i, uid in enumerate(self.possible_agents)} self._reset_once = False return self._observations, self._info def render(self, *args, **kwargs) -> None: - """Render the environment - """ + """Render the environment""" return None def close(self) -> None: - """Close the environment - """ + """Close the environment""" pass diff --git a/skrl/envs/wrappers/torch/brax_envs.py b/skrl/envs/wrappers/torch/brax_envs.py index 775994f9..fe1281e1 100644 --- a/skrl/envs/wrappers/torch/brax_envs.py +++ b/skrl/envs/wrappers/torch/brax_envs.py @@ -10,7 +10,7 @@ convert_gym_space, flatten_tensorized_space, tensorize_space, - unflatten_tensorized_space + unflatten_tensorized_space, ) @@ -25,6 +25,7 @@ def __init__(self, env: Any) -> None: import brax.envs.wrappers.gym import brax.envs.wrappers.torch + env = brax.envs.wrappers.gym.VectorGymWrapper(env) env = brax.envs.wrappers.torch.TorchWrapper(env, device=self.device) self._env = env @@ -32,14 +33,12 @@ def __init__(self, env: Any) -> None: @property def observation_space(self) -> gymnasium.Space: - """Observation space 
- """ + """Observation space""" return convert_gym_space(self._unwrapped.observation_space, squeeze_batch_dimension=True) @property def action_space(self) -> gymnasium.Space: - """Action space - """ + """Action space""" return convert_gym_space(self._unwrapped.action_space, squeeze_batch_dimension=True) def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: @@ -67,13 +66,13 @@ def reset(self) -> Tuple[torch.Tensor, Any]: return observation, {} def render(self, *args, **kwargs) -> None: - """Render the environment - """ + """Render the environment""" frame = self._env.render(mode="rgb_array") # render the frame using OpenCV try: import cv2 + cv2.imshow("env", cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) cv2.waitKey(1) except ImportError as e: @@ -81,7 +80,6 @@ def render(self, *args, **kwargs) -> None: return frame def close(self) -> None: - """Close the environment - """ + """Close the environment""" # self._env.close() raises AttributeError: 'VectorGymWrapper' object has no attribute 'closed' pass diff --git a/skrl/envs/wrappers/torch/deepmind_envs.py b/skrl/envs/wrappers/torch/deepmind_envs.py index 1777703a..bb7d70be 100644 --- a/skrl/envs/wrappers/torch/deepmind_envs.py +++ b/skrl/envs/wrappers/torch/deepmind_envs.py @@ -12,7 +12,7 @@ flatten_tensorized_space, tensorize_space, unflatten_tensorized_space, - untensorize_space + untensorize_space, ) @@ -26,18 +26,17 @@ def __init__(self, env: Any) -> None: super().__init__(env) from dm_env import specs + self._specs = specs @property def observation_space(self) -> gymnasium.Space: - """Observation space - """ + """Observation space""" return self._spec_to_space(self._env.observation_spec()) @property def action_space(self) -> gymnasium.Space: - """Action space - """ + """Action space""" return self._spec_to_space(self._env.action_spec()) def _spec_to_space(self, spec: Any) -> gymnasium.Space: @@ -54,15 +53,19 @@ def _spec_to_space(self, spec: Any) -> gymnasium.Space: if isinstance(spec, self._specs.DiscreteArray): return gymnasium.spaces.Discrete(spec.num_values) elif isinstance(spec, self._specs.BoundedArray): - return gymnasium.spaces.Box(shape=spec.shape, - dtype=spec.dtype, - low=spec.minimum if spec.minimum.ndim else np.full(spec.shape, spec.minimum), - high=spec.maximum if spec.maximum.ndim else np.full(spec.shape, spec.maximum)) + return gymnasium.spaces.Box( + shape=spec.shape, + dtype=spec.dtype, + low=spec.minimum if spec.minimum.ndim else np.full(spec.shape, spec.minimum), + high=spec.maximum if spec.maximum.ndim else np.full(spec.shape, spec.maximum), + ) elif isinstance(spec, self._specs.Array): - return gymnasium.spaces.Box(shape=spec.shape, - dtype=spec.dtype, - low=np.full(spec.shape, float("-inf")), - high=np.full(spec.shape, float("inf"))) + return gymnasium.spaces.Box( + shape=spec.shape, + dtype=spec.dtype, + low=np.full(spec.shape, float("-inf")), + high=np.full(spec.shape, float("inf")), + ) elif isinstance(spec, collections.OrderedDict): return gymnasium.spaces.Dict({k: self._spec_to_space(v) for k, v in spec.items()}) else: @@ -80,18 +83,22 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch actions = untensorize_space(self.action_space, unflatten_tensorized_space(self.action_space, actions)) timestep = self._env.step(actions) - observation = flatten_tensorized_space(tensorize_space(self.observation_space, timestep.observation, self.device)) + observation = flatten_tensorized_space( + tensorize_space(self.observation_space, 
timestep.observation, self.device) + ) reward = timestep.reward if timestep.reward is not None else 0 terminated = timestep.last() truncated = False info = {} # convert response to torch - return observation, \ - torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1), \ - torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1), \ - torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1), \ - info + return ( + observation, + torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1), + torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1), + torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1), + info, + ) def reset(self) -> Tuple[torch.Tensor, Any]: """Reset the environment @@ -100,7 +107,9 @@ def reset(self) -> Tuple[torch.Tensor, Any]: :rtype: torch.Tensor """ timestep = self._env.reset() - observation = flatten_tensorized_space(tensorize_space(self.observation_space, timestep.observation, self.device)) + observation = flatten_tensorized_space( + tensorize_space(self.observation_space, timestep.observation, self.device) + ) return observation, {} def render(self, *args, **kwargs) -> np.ndarray: @@ -114,6 +123,7 @@ def render(self, *args, **kwargs) -> np.ndarray: # render the frame using OpenCV try: import cv2 + cv2.imshow("env", cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) cv2.waitKey(1) except ImportError as e: @@ -121,6 +131,5 @@ def render(self, *args, **kwargs) -> np.ndarray: return frame def close(self) -> None: - """Close the environment - """ + """Close the environment""" self._env.close() diff --git a/skrl/envs/wrappers/torch/gym_envs.py b/skrl/envs/wrappers/torch/gym_envs.py index a432d211..887896f9 100644 --- a/skrl/envs/wrappers/torch/gym_envs.py +++ b/skrl/envs/wrappers/torch/gym_envs.py @@ -13,7 +13,7 @@ flatten_tensorized_space, tensorize_space, unflatten_tensorized_space, - untensorize_space + untensorize_space, ) @@ -33,6 +33,7 @@ def __init__(self, env: Any) -> None: np.bool8 = np.bool import gym + self._vectorized = False try: if isinstance(env, gym.vector.VectorEnv): @@ -49,16 +50,14 @@ def __init__(self, env: Any) -> None: @property def observation_space(self) -> gymnasium.Space: - """Observation space - """ + """Observation space""" if self._vectorized: return convert_gym_space(self._env.single_observation_space) return convert_gym_space(self._env.observation_space) @property def action_space(self) -> gymnasium.Space: - """Action space - """ + """Action space""" if self._vectorized: return convert_gym_space(self._env.single_action_space) return convert_gym_space(self._env.action_space) @@ -72,9 +71,11 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch :return: Observation, reward, terminated, truncated, info :rtype: tuple of torch.Tensor and any other info """ - actions = untensorize_space(self.action_space, - unflatten_tensorized_space(self.action_space, actions), - squeeze_batch_dimension=not self._vectorized) + actions = untensorize_space( + self.action_space, + unflatten_tensorized_space(self.action_space, actions), + squeeze_batch_dimension=not self._vectorized, + ) if self._deprecated_api: observation, reward, terminated, info = self._env.step(actions) @@ -116,7 +117,9 @@ def reset(self) -> Tuple[torch.Tensor, Any]: self._info = {} else: observation, self._info = self._env.reset() - self._observation = 
flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device)) + self._observation = flatten_tensorized_space( + tensorize_space(self.observation_space, observation, self.device) + ) self._reset_once = False return self._observation, self._info @@ -129,13 +132,11 @@ def reset(self) -> Tuple[torch.Tensor, Any]: return observation, info def render(self, *args, **kwargs) -> Any: - """Render the environment - """ + """Render the environment""" if self._vectorized: return None return self._env.render(*args, **kwargs) def close(self) -> None: - """Close the environment - """ + """Close the environment""" self._env.close() diff --git a/skrl/envs/wrappers/torch/gymnasium_envs.py b/skrl/envs/wrappers/torch/gymnasium_envs.py index 1de58993..1708756e 100644 --- a/skrl/envs/wrappers/torch/gymnasium_envs.py +++ b/skrl/envs/wrappers/torch/gymnasium_envs.py @@ -10,7 +10,7 @@ flatten_tensorized_space, tensorize_space, unflatten_tensorized_space, - untensorize_space + untensorize_space, ) @@ -35,16 +35,14 @@ def __init__(self, env: Any) -> None: @property def observation_space(self) -> gymnasium.Space: - """Observation space - """ + """Observation space""" if self._vectorized: return self._env.single_observation_space return self._env.observation_space @property def action_space(self) -> gymnasium.Space: - """Action space - """ + """Action space""" if self._vectorized: return self._env.single_action_space return self._env.action_space @@ -58,9 +56,11 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch :return: Observation, reward, terminated, truncated, info :rtype: tuple of torch.Tensor and any other info """ - actions = untensorize_space(self.action_space, - unflatten_tensorized_space(self.action_space, actions), - squeeze_batch_dimension=not self._vectorized) + actions = untensorize_space( + self.action_space, + unflatten_tensorized_space(self.action_space, actions), + squeeze_batch_dimension=not self._vectorized, + ) observation, reward, terminated, truncated, info = self._env.step(actions) @@ -87,7 +87,9 @@ def reset(self) -> Tuple[torch.Tensor, Any]: if self._vectorized: if self._reset_once: observation, self._info = self._env.reset() - self._observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, self.device)) + self._observation = flatten_tensorized_space( + tensorize_space(self.observation_space, observation, self.device) + ) self._reset_once = False return self._observation, self._info @@ -96,13 +98,11 @@ def reset(self) -> Tuple[torch.Tensor, Any]: return observation, info def render(self, *args, **kwargs) -> Any: - """Render the environment - """ + """Render the environment""" if self._vectorized: return self._env.call("render", *args, **kwargs) return self._env.render(*args, **kwargs) def close(self) -> None: - """Close the environment - """ + """Close the environment""" self._env.close() diff --git a/skrl/envs/wrappers/torch/isaacgym_envs.py b/skrl/envs/wrappers/torch/isaacgym_envs.py index 0d5a3cd2..becfcd4a 100644 --- a/skrl/envs/wrappers/torch/isaacgym_envs.py +++ b/skrl/envs/wrappers/torch/isaacgym_envs.py @@ -9,7 +9,7 @@ convert_gym_space, flatten_tensorized_space, tensorize_space, - unflatten_tensorized_space + unflatten_tensorized_space, ) @@ -28,14 +28,12 @@ def __init__(self, env: Any) -> None: @property def observation_space(self) -> gymnasium.Space: - """Observation space - """ + """Observation space""" return convert_gym_space(self._unwrapped.observation_space) @property def 
action_space(self) -> gymnasium.Space: - """Action space - """ + """Action space""" return convert_gym_space(self._unwrapped.action_space) def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: @@ -47,7 +45,9 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch :return: Observation, reward, terminated, truncated, info :rtype: tuple of torch.Tensor and any other info """ - observations, reward, terminated, self._info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) + observations, reward, terminated, self._info = self._env.step( + unflatten_tensorized_space(self.action_space, actions) + ) self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations)) truncated = self._info["time_outs"] if "time_outs" in self._info else torch.zeros_like(terminated) return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info @@ -65,13 +65,11 @@ def reset(self) -> Tuple[torch.Tensor, Any]: return self._observations, self._info def render(self, *args, **kwargs) -> None: - """Render the environment - """ + """Render the environment""" return None def close(self) -> None: - """Close the environment - """ + """Close the environment""" pass @@ -90,20 +88,17 @@ def __init__(self, env: Any) -> None: @property def observation_space(self) -> gymnasium.Space: - """Observation space - """ + """Observation space""" return convert_gym_space(self._unwrapped.observation_space) @property def action_space(self) -> gymnasium.Space: - """Action space - """ + """Action space""" return convert_gym_space(self._unwrapped.action_space) @property def state_space(self) -> Union[gymnasium.Space, None]: - """State space - """ + """State space""" try: if self.num_states: return convert_gym_space(self._unwrapped.state_space) @@ -120,7 +115,9 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch :return: Observation, reward, terminated, truncated, info :rtype: tuple of torch.Tensor and any other info """ - observations, reward, terminated, self._info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) + observations, reward, terminated, self._info = self._env.step( + unflatten_tensorized_space(self.action_space, actions) + ) self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) truncated = self._info["time_outs"] if "time_outs" in self._info else torch.zeros_like(terminated) return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info @@ -138,11 +135,9 @@ def reset(self) -> Tuple[torch.Tensor, Any]: return self._observations, self._info def render(self, *args, **kwargs) -> None: - """Render the environment - """ + """Render the environment""" return None def close(self) -> None: - """Close the environment - """ + """Close the environment""" pass diff --git a/skrl/envs/wrappers/torch/isaaclab_envs.py b/skrl/envs/wrappers/torch/isaaclab_envs.py index c3ed2589..e354c196 100644 --- a/skrl/envs/wrappers/torch/isaaclab_envs.py +++ b/skrl/envs/wrappers/torch/isaaclab_envs.py @@ -23,8 +23,7 @@ def __init__(self, env: Any) -> None: @property def state_space(self) -> Union[gymnasium.Space, None]: - """State space - """ + """State space""" try: return self._unwrapped.single_observation_space["critic"] except KeyError: @@ -36,8 +35,7 @@ def state_space(self) -> Union[gymnasium.Space, None]: @property def 
observation_space(self) -> gymnasium.Space: - """Observation space - """ + """Observation space""" try: return self._unwrapped.single_observation_space["policy"] except: @@ -45,8 +43,7 @@ def observation_space(self) -> gymnasium.Space: @property def action_space(self) -> gymnasium.Space: - """Action space - """ + """Action space""" try: return self._unwrapped.single_action_space except: @@ -74,18 +71,18 @@ def reset(self) -> Tuple[torch.Tensor, Any]: """ if self._reset_once: observations, self._info = self._env.reset() - self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["policy"])) + self._observations = flatten_tensorized_space( + tensorize_space(self.observation_space, observations["policy"]) + ) self._reset_once = False return self._observations, self._info def render(self, *args, **kwargs) -> None: - """Render the environment - """ + """Render the environment""" return None def close(self) -> None: - """Close the environment - """ + """Close the environment""" self._env.close() @@ -102,9 +99,13 @@ def __init__(self, env: Any) -> None: self._observations = None self._info = {} - def step(self, actions: Mapping[str, torch.Tensor]) -> \ - Tuple[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], - Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], Mapping[str, Any]]: + def step(self, actions: Mapping[str, torch.Tensor]) -> Tuple[ + Mapping[str, torch.Tensor], + Mapping[str, torch.Tensor], + Mapping[str, torch.Tensor], + Mapping[str, torch.Tensor], + Mapping[str, Any], + ]: """Perform a step in the environment :param actions: The actions to perform @@ -115,12 +116,16 @@ def step(self, actions: Mapping[str, torch.Tensor]) -> \ """ actions = {k: unflatten_tensorized_space(self.action_spaces[k], v) for k, v in actions.items()} observations, rewards, terminated, truncated, self._info = self._env.step(actions) - self._observations = {k: flatten_tensorized_space(tensorize_space(self.observation_spaces[k], v)) for k, v in observations.items()} - return self._observations, \ - {k: v.view(-1, 1) for k, v in rewards.items()}, \ - {k: v.view(-1, 1) for k, v in terminated.items()}, \ - {k: v.view(-1, 1) for k, v in truncated.items()}, \ - self._info + self._observations = { + k: flatten_tensorized_space(tensorize_space(self.observation_spaces[k], v)) for k, v in observations.items() + } + return ( + self._observations, + {k: v.view(-1, 1) for k, v in rewards.items()}, + {k: v.view(-1, 1) for k, v in terminated.items()}, + {k: v.view(-1, 1) for k, v in truncated.items()}, + self._info, + ) def reset(self) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, Any]]: """Reset the environment @@ -130,7 +135,10 @@ def reset(self) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, Any]]: """ if self._reset_once: observations, self._info = self._env.reset() - self._observations = {k: flatten_tensorized_space(tensorize_space(self.observation_spaces[k], v)) for k, v in observations.items()} + self._observations = { + k: flatten_tensorized_space(tensorize_space(self.observation_spaces[k], v)) + for k, v in observations.items() + } self._reset_once = False return self._observations, self._info @@ -146,11 +154,9 @@ def state(self) -> torch.Tensor: return state def render(self, *args, **kwargs) -> None: - """Render the environment - """ + """Render the environment""" return None def close(self) -> None: - """Close the environment - """ + """Close the environment""" self._env.close() diff --git a/skrl/envs/wrappers/torch/omniverse_isaacgym_envs.py 
b/skrl/envs/wrappers/torch/omniverse_isaacgym_envs.py index 88f11a18..cae65abc 100644 --- a/skrl/envs/wrappers/torch/omniverse_isaacgym_envs.py +++ b/skrl/envs/wrappers/torch/omniverse_isaacgym_envs.py @@ -38,7 +38,9 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch :return: Observation, reward, terminated, truncated, info :rtype: tuple of torch.Tensor and any other info """ - observations, reward, terminated, self._info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) + observations, reward, terminated, self._info = self._env.step( + unflatten_tensorized_space(self.action_space, actions) + ) self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) truncated = self._info["time_outs"] if "time_outs" in self._info else torch.zeros_like(terminated) return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info @@ -56,11 +58,9 @@ def reset(self) -> Tuple[torch.Tensor, Any]: return self._observations, self._info def render(self, *args, **kwargs) -> None: - """Render the environment - """ + """Render the environment""" return None def close(self) -> None: - """Close the environment - """ + """Close the environment""" self._env.close() diff --git a/skrl/envs/wrappers/torch/pettingzoo_envs.py b/skrl/envs/wrappers/torch/pettingzoo_envs.py index 7b55785c..f323aa96 100644 --- a/skrl/envs/wrappers/torch/pettingzoo_envs.py +++ b/skrl/envs/wrappers/torch/pettingzoo_envs.py @@ -9,7 +9,7 @@ flatten_tensorized_space, tensorize_space, unflatten_tensorized_space, - untensorize_space + untensorize_space, ) @@ -22,9 +22,13 @@ def __init__(self, env: Any) -> None: """ super().__init__(env) - def step(self, actions: Mapping[str, torch.Tensor]) -> \ - Tuple[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], - Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], Mapping[str, Any]]: + def step(self, actions: Mapping[str, torch.Tensor]) -> Tuple[ + Mapping[str, torch.Tensor], + Mapping[str, torch.Tensor], + Mapping[str, torch.Tensor], + Mapping[str, torch.Tensor], + Mapping[str, Any], + ]: """Perform a step in the environment :param actions: The actions to perform @@ -33,14 +37,29 @@ def step(self, actions: Mapping[str, torch.Tensor]) -> \ :return: Observation, reward, terminated, truncated, info :rtype: tuple of dictionaries torch.Tensor and any other info """ - actions = {uid: untensorize_space(self.action_spaces[uid], unflatten_tensorized_space(self.action_spaces[uid], action)) for uid, action in actions.items()} + actions = { + uid: untensorize_space(self.action_spaces[uid], unflatten_tensorized_space(self.action_spaces[uid], action)) + for uid, action in actions.items() + } observations, rewards, terminated, truncated, infos = self._env.step(actions) # convert response to torch - observations = {uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, device=self.device)) for uid, value in observations.items()} - rewards = {uid: torch.tensor(value, device=self.device, dtype=torch.float32).view(self.num_envs, -1) for uid, value in rewards.items()} - terminated = {uid: torch.tensor(value, device=self.device, dtype=torch.bool).view(self.num_envs, -1) for uid, value in terminated.items()} - truncated = {uid: torch.tensor(value, device=self.device, dtype=torch.bool).view(self.num_envs, -1) for uid, value in truncated.items()} + observations = { + uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, 
device=self.device)) + for uid, value in observations.items() + } + rewards = { + uid: torch.tensor(value, device=self.device, dtype=torch.float32).view(self.num_envs, -1) + for uid, value in rewards.items() + } + terminated = { + uid: torch.tensor(value, device=self.device, dtype=torch.bool).view(self.num_envs, -1) + for uid, value in terminated.items() + } + truncated = { + uid: torch.tensor(value, device=self.device, dtype=torch.bool).view(self.num_envs, -1) + for uid, value in truncated.items() + } return observations, rewards, terminated, truncated, infos def state(self) -> torch.Tensor: @@ -49,7 +68,9 @@ def state(self) -> torch.Tensor: :return: State :rtype: torch.Tensor """ - return flatten_tensorized_space(tensorize_space(next(iter(self.state_spaces.values())), self._env.state(), device=self.device)) + return flatten_tensorized_space( + tensorize_space(next(iter(self.state_spaces.values())), self._env.state(), device=self.device) + ) def reset(self) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, Any]]: """Reset the environment @@ -65,15 +86,16 @@ def reset(self) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, Any]]: observations, infos = outputs # convert response to torch - observations = {uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, device=self.device)) for uid, value in observations.items()} + observations = { + uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, device=self.device)) + for uid, value in observations.items() + } return observations, infos def render(self, *args, **kwargs) -> Any: - """Render the environment - """ + """Render the environment""" return self._env.render(*args, **kwargs) def close(self) -> None: - """Close the environment - """ + """Close the environment""" self._env.close() diff --git a/skrl/envs/wrappers/torch/robosuite_envs.py b/skrl/envs/wrappers/torch/robosuite_envs.py index d07b438b..d26db207 100644 --- a/skrl/envs/wrappers/torch/robosuite_envs.py +++ b/skrl/envs/wrappers/torch/robosuite_envs.py @@ -33,14 +33,12 @@ def state_space(self) -> gymnasium.Space: @property def observation_space(self) -> gymnasium.Space: - """Observation space - """ + """Observation space""" return convert_gym_space(self._observation_space) @property def action_space(self) -> gymnasium.Space: - """Action space - """ + """Action space""" return convert_gym_space(self._action_space) def _spec_to_space(self, spec: Any) -> gymnasium.Space: @@ -55,15 +53,14 @@ def _spec_to_space(self, spec: Any) -> gymnasium.Space: :rtype: gymnasium.Space """ if type(spec) is tuple: - return gymnasium.spaces.Box(shape=spec[0].shape, - dtype=np.float32, - low=spec[0], - high=spec[1]) + return gymnasium.spaces.Box(shape=spec[0].shape, dtype=np.float32, low=spec[0], high=spec[1]) elif isinstance(spec, np.ndarray): - return gymnasium.spaces.Box(shape=spec.shape, - dtype=np.float32, - low=np.full(spec.shape, float("-inf")), - high=np.full(spec.shape, float("inf"))) + return gymnasium.spaces.Box( + shape=spec.shape, + dtype=np.float32, + low=np.full(spec.shape, float("-inf")), + high=np.full(spec.shape, float("inf")), + ) elif isinstance(spec, collections.OrderedDict): return gymnasium.spaces.Dict({k: self._spec_to_space(v) for k, v in spec.items()}) else: @@ -85,8 +82,9 @@ def _observation_to_tensor(self, observation: Any, spec: Optional[Any] = None) - if isinstance(spec, np.ndarray): return torch.tensor(observation, device=self.device, dtype=torch.float32).reshape(self.num_envs, -1) elif isinstance(spec, 
collections.OrderedDict): - return torch.cat([self._observation_to_tensor(observation[k], spec[k]) \ - for k in sorted(spec.keys())], dim=-1).reshape(self.num_envs, -1) + return torch.cat( + [self._observation_to_tensor(observation[k], spec[k]) for k in sorted(spec.keys())], dim=-1 + ).reshape(self.num_envs, -1) else: raise ValueError(f"Observation spec type {type(spec)} not supported. Please report this issue") @@ -122,11 +120,13 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch info = {} # convert response to torch - return self._observation_to_tensor(observation), \ - torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1), \ - torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1), \ - torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1), \ - info + return ( + self._observation_to_tensor(observation), + torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1), + torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1), + torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1), + info, + ) def reset(self) -> Tuple[torch.Tensor, Any]: """Reset the environment @@ -138,11 +138,9 @@ def reset(self) -> Tuple[torch.Tensor, Any]: return self._observation_to_tensor(observation), {} def render(self, *args, **kwargs) -> None: - """Render the environment - """ + """Render the environment""" self._env.render(*args, **kwargs) def close(self) -> None: - """Close the environment - """ + """Close the environment""" self._env.close() diff --git a/skrl/memories/jax/base.py b/skrl/memories/jax/base.py index f6168a48..630dd8e5 100644 --- a/skrl/memories/jax/base.py +++ b/skrl/memories/jax/base.py @@ -18,27 +18,30 @@ # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @jax.jit def _copyto(dst, src): - """NumPy function not yet implemented - """ + """NumPy function not yet implemented""" return dst.at[:].set(src) + @jax.jit def _copyto_i(dst, src, i): return dst.at[i].set(src) + @jax.jit def _copyto_i_j(dst, src, i, j): return dst.at[i, j].set(src) class Memory: - def __init__(self, - memory_size: int, - num_envs: int = 1, - device: Optional[jax.Device] = None, - export: bool = False, - export_format: str = "pt", # TODO: set default format for jax - export_directory: str = "") -> None: + def __init__( + self, + memory_size: int, + num_envs: int = 1, + device: Optional[jax.Device] = None, + export: bool = False, + export_format: str = "pt", # TODO: set default format for jax + export_directory: str = "", + ) -> None: """Base class representing a memory with circular buffers Buffers are jax or numpy arrays with shape (memory size, number of environments, data size). 
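
(For reference only, not part of the patch: the circular-buffer layout described in the docstring above can be sketched with plain NumPy. The (memory size, number of environments, data size) shape, the flattened view, and the sequence-index expression mirror the __init__/create_tensor code shown in this diff; the concrete sizes are illustrative assumptions.)

import numpy as np

memory_size, num_envs, size = 4, 2, 3

# one internal buffer: (memory size, number of environments, data size)
states = np.zeros((memory_size, num_envs, size), dtype=np.float32)

# flattened sampling view: (memory_size * num_envs, size)
states_view = states.reshape(-1, size)

# sequence-ordered indexes, as built in __init__ above:
# all memory steps of environment 0 first, then environment 1, ...
all_sequence_indexes = np.concatenate(
    [np.arange(i, memory_size * num_envs + i, num_envs) for i in range(num_envs)]
)
assert states_view[all_sequence_indexes].shape == (memory_size * num_envs, size)
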
@@ -72,7 +75,7 @@ def __init__(self, else: self.device = device if type(device) == str: - device_type, device_index = f"{device}:0".split(':')[:2] + device_type, device_index = f"{device}:0".split(":")[:2] self.device = jax.devices(device_type)[int(device_index)] # internal variables @@ -86,7 +89,9 @@ def __init__(self, self._views = True # whether the views are not array copies self.sampling_indexes = None - self.all_sequence_indexes = np.concatenate([np.arange(i, memory_size * num_envs + i, num_envs) for i in range(num_envs)]) + self.all_sequence_indexes = np.concatenate( + [np.arange(i, memory_size * num_envs + i, num_envs) for i in range(num_envs)] + ) # exporting data self.export = export @@ -109,12 +114,15 @@ def __len__(self) -> int: def _get_tensors_view(self, name): if self.tensors_keep_dimensions[name]: - return self.tensors_view[name] if self._views else self.tensors[name].reshape(-1, *self.tensors_keep_dimensions[name]) + return ( + self.tensors_view[name] + if self._views + else self.tensors[name].reshape(-1, *self.tensors_keep_dimensions[name]) + ) return self.tensors_view[name] if self._views else self.tensors[name].reshape(-1, self.tensors[name].shape[-1]) def share_memory(self) -> None: - """Share the tensors between processes - """ + """Share the tensors between processes""" for tensor in self.tensors.values(): pass @@ -157,11 +165,13 @@ def set_tensor_by_name(self, name: str, tensor: Union[np.ndarray, jax.Array]) -> else: np.copyto(self.tensors[name], tensor) - def create_tensor(self, - name: str, - size: Union[int, Tuple[int], gymnasium.Space], - dtype: Optional[np.dtype] = None, - keep_dimensions: bool = False) -> bool: + def create_tensor( + self, + name: str, + size: Union[int, Tuple[int], gymnasium.Space], + dtype: Optional[np.dtype] = None, + keep_dimensions: bool = False, + ) -> bool: """Create a new internal tensor in memory The tensor will have a 3-components shape (memory size, number of environments, size). @@ -194,7 +204,9 @@ def create_tensor(self, raise ValueError(f"Dtype of tensor {name} ({dtype}) doesn't match the existing one ({tensor.dtype})") return False # define tensor shape - tensor_shape = (self.memory_size, self.num_envs, *size) if keep_dimensions else (self.memory_size, self.num_envs, size) + tensor_shape = ( + (self.memory_size, self.num_envs, *size) if keep_dimensions else (self.memory_size, self.num_envs, size) + ) view_shape = (-1, *size) if keep_dimensions else (-1, size) # create tensor (_tensor_) and add it to the internal storage if self._jax: @@ -261,7 +273,9 @@ def add_samples(self, **tensors: Mapping[str, Union[np.ndarray, jax.Array]]) -> :raises ValueError: No tensors were provided or the tensors have incompatible shapes """ if not tensors: - raise ValueError("No samples to be recorded in memory. Pass samples as key-value arguments (where key is the tensor name)") + raise ValueError( + "No samples to be recorded in memory. 
Pass samples as key-value arguments (where key is the tensor name)" + ) # dimensions and shapes of the tensors (assume all tensors have the dimensions of the first tensor) tmp = tensors.get("states", tensors[next(iter(tensors))]) # ask for states first @@ -283,7 +297,11 @@ def add_samples(self, **tensors: Mapping[str, Union[np.ndarray, jax.Array]]) -> raise NotImplementedError # TODO: for name, tensor in tensors.items(): if name in self.tensors: - self.tensors[name] = self.tensors[name].at[self.memory_index, self.env_index:self.env_index + tensor.shape[0]].set(tensor) + self.tensors[name] = ( + self.tensors[name] + .at[self.memory_index, self.env_index : self.env_index + tensor.shape[0]] + .set(tensor) + ) self.env_index += tensor.shape[0] # single environment - multi sample (number of environments greater than num_envs (num_envs = 1)) elif dim > 1 and self.num_envs == 1: @@ -293,11 +311,17 @@ def add_samples(self, **tensors: Mapping[str, Union[np.ndarray, jax.Array]]) -> num_samples = min(shape[0], self.memory_size - self.memory_index) remaining_samples = shape[0] - num_samples # copy the first n samples - self.tensors[name] = self.tensors[name].at[self.memory_index:self.memory_index + num_samples].set(tensor[:num_samples].unsqueeze(dim=1)) + self.tensors[name] = ( + self.tensors[name] + .at[self.memory_index : self.memory_index + num_samples] + .set(tensor[:num_samples].unsqueeze(dim=1)) + ) self.memory_index += num_samples # storage remaining samples if remaining_samples > 0: - self.tensors[name] = self.tensors[name].at[:remaining_samples].set(tensor[num_samples:].unsqueeze(dim=1)) + self.tensors[name] = ( + self.tensors[name].at[:remaining_samples].set(tensor[num_samples:].unsqueeze(dim=1)) + ) self.memory_index = remaining_samples # single environment elif dim == 1: @@ -325,11 +349,9 @@ def add_samples(self, **tensors: Mapping[str, Union[np.ndarray, jax.Array]]) -> if self.export: self.save(directory=self.export_directory, format=self.export_format) - def sample(self, - names: Tuple[str], - batch_size: int, - mini_batches: int = 1, - sequence_length: int = 1) -> List[List[Union[np.ndarray, jax.Array]]]: + def sample( + self, names: Tuple[str], batch_size: int, mini_batches: int = 1, sequence_length: int = 1 + ) -> List[List[Union[np.ndarray, jax.Array]]]: """Data sampling method to be implemented by the inheriting classes :param names: Tensors names from which to obtain the samples @@ -349,7 +371,9 @@ def sample(self, """ raise NotImplementedError("The sampling method (.sample()) is not implemented") - def sample_by_index(self, names: Tuple[str], indexes: Union[tuple, np.ndarray, jax.Array], mini_batches: int = 1) -> List[List[Union[np.ndarray, jax.Array]]]: + def sample_by_index( + self, names: Tuple[str], indexes: Union[tuple, np.ndarray, jax.Array], mini_batches: int = 1 + ) -> List[List[Union[np.ndarray, jax.Array]]]: """Sample data from memory according to their indexes :param names: Tensors names from which to obtain the samples @@ -369,7 +393,9 @@ def sample_by_index(self, names: Tuple[str], indexes: Union[tuple, np.ndarray, j return [[view[batch] for view in views] for batch in batches] return [[self._get_tensors_view(name)[indexes] for name in names]] - def sample_all(self, names: Tuple[str], mini_batches: int = 1, sequence_length: int = 1) -> List[List[Union[np.ndarray, jax.Array]]]: + def sample_all( + self, names: Tuple[str], mini_batches: int = 1, sequence_length: int = 1 + ) -> List[List[Union[np.ndarray, jax.Array]]]: """Sample all data from memory :param names: Tensors 
names from which to obtain the samples @@ -426,12 +452,16 @@ def save(self, directory: str = "", format: str = "pt") -> None: if not directory: directory = self.export_directory os.makedirs(os.path.join(directory, "memories"), exist_ok=True) - memory_path = os.path.join(directory, "memories", \ - "{}_memory_{}.{}".format(datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S-%f"), hex(id(self)), format)) + memory_path = os.path.join( + directory, + "memories", + "{}_memory_{}.{}".format(datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S-%f"), hex(id(self)), format), + ) # torch if format == "pt": import torch + torch.save({name: self.tensors[name] for name in self.get_tensor_names()}, memory_path) # numpy elif format == "npz": @@ -447,7 +477,16 @@ def save(self, directory: str = "", format: str = "pt") -> None: writer.writerow([item for sublist in headers for item in sublist]) # write rows for i in range(len(self)): - writer.writerow(functools.reduce(operator.iconcat, [self.tensors[name].reshape(-1, self.tensors[name].shape[-1])[i].tolist() for name in names], [])) + writer.writerow( + functools.reduce( + operator.iconcat, + [ + self.tensors[name].reshape(-1, self.tensors[name].shape[-1])[i].tolist() + for name in names + ], + [], + ) + ) # unsupported format else: raise ValueError(f"Unsupported format: {format}. Available formats: pt, csv, npz") @@ -468,6 +507,7 @@ def load(self, path: str) -> None: # torch if path.endswith(".pt"): import torch + data = torch.load(path) for name in self.get_tensor_names(): setattr(self, f"_tensor_{name}", jnp.array(data[name].cpu().numpy())) diff --git a/skrl/memories/jax/random.py b/skrl/memories/jax/random.py index a301c3b3..b5340d46 100644 --- a/skrl/memories/jax/random.py +++ b/skrl/memories/jax/random.py @@ -7,14 +7,16 @@ class RandomMemory(Memory): - def __init__(self, - memory_size: int, - num_envs: int = 1, - device: Optional[jax.Device] = None, - export: bool = False, - export_format: str = "pt", - export_directory: str = "", - replacement=True) -> None: + def __init__( + self, + memory_size: int, + num_envs: int = 1, + device: Optional[jax.Device] = None, + export: bool = False, + export_format: str = "pt", + export_directory: str = "", + replacement=True, + ) -> None: """Random sampling memory Sample a batch from memory randomly diff --git a/skrl/memories/torch/base.py b/skrl/memories/torch/base.py index acc5d83a..b239eb73 100644 --- a/skrl/memories/torch/base.py +++ b/skrl/memories/torch/base.py @@ -15,13 +15,15 @@ class Memory: - def __init__(self, - memory_size: int, - num_envs: int = 1, - device: Optional[Union[str, torch.device]] = None, - export: bool = False, - export_format: str = "pt", - export_directory: str = "") -> None: + def __init__( + self, + memory_size: int, + num_envs: int = 1, + device: Optional[Union[str, torch.device]] = None, + export: bool = False, + export_format: str = "pt", + export_directory: str = "", + ) -> None: """Base class representing a memory with circular buffers Buffers are torch tensors with shape (memory size, number of environments, data size). 
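
(For reference only, not part of the patch: a minimal sketch of how the memory API being reformatted here is typically driven. The class and method signatures shown, RandomMemory, create_tensor, add_samples and sample, are taken from the code in this diff; the tensor names, sizes, device and batch size are illustrative assumptions.)

import torch
from skrl.memories.torch import RandomMemory

# buffer with 1000 steps per environment and 4 parallel environments
memory = RandomMemory(memory_size=1000, num_envs=4, device="cpu")
memory.create_tensor(name="states", size=8, dtype=torch.float32)
memory.create_tensor(name="actions", size=2, dtype=torch.float32)
memory.create_tensor(name="rewards", size=1, dtype=torch.float32)

# samples are passed as key-value arguments (one row per environment)
memory.add_samples(
    states=torch.zeros(4, 8), actions=torch.zeros(4, 2), rewards=torch.zeros(4, 1)
)

# draw a random mini-batch from the flattened (memory_size * num_envs, size) views
states, actions, rewards = memory.sample(names=("states", "actions", "rewards"), batch_size=32)[0]
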
@@ -48,7 +50,9 @@ def __init__(self, """ self.memory_size = memory_size self.num_envs = num_envs - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device) + self.device = ( + torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device) + ) # internal variables self.filled = False @@ -60,7 +64,9 @@ def __init__(self, self.tensors_keep_dimensions = {} self.sampling_indexes = None - self.all_sequence_indexes = np.concatenate([np.arange(i, memory_size * num_envs + i, num_envs) for i in range(num_envs)]) + self.all_sequence_indexes = np.concatenate( + [np.arange(i, memory_size * num_envs + i, num_envs) for i in range(num_envs)] + ) # exporting data self.export = export @@ -82,8 +88,7 @@ def __len__(self) -> int: return self.memory_size * self.num_envs if self.filled else self.memory_index * self.num_envs + self.env_index def share_memory(self) -> None: - """Share the tensors between processes - """ + """Share the tensors between processes""" for tensor in self.tensors.values(): if not tensor.is_cuda: tensor.share_memory_() @@ -125,11 +130,13 @@ def set_tensor_by_name(self, name: str, tensor: torch.Tensor) -> None: with torch.no_grad(): self.tensors[name].copy_(tensor) - def create_tensor(self, - name: str, - size: Union[int, Tuple[int], gymnasium.Space], - dtype: Optional[torch.dtype] = None, - keep_dimensions: bool = False) -> bool: + def create_tensor( + self, + name: str, + size: Union[int, Tuple[int], gymnasium.Space], + dtype: Optional[torch.dtype] = None, + keep_dimensions: bool = False, + ) -> bool: """Create a new internal tensor in memory The tensor will have a 3-components shape (memory size, number of environments, size). @@ -162,7 +169,9 @@ def create_tensor(self, raise ValueError(f"Dtype of tensor {name} ({dtype}) doesn't match the existing one ({tensor.dtype})") return False # define tensor shape - tensor_shape = (self.memory_size, self.num_envs, *size) if keep_dimensions else (self.memory_size, self.num_envs, size) + tensor_shape = ( + (self.memory_size, self.num_envs, *size) if keep_dimensions else (self.memory_size, self.num_envs, size) + ) view_shape = (-1, *size) if keep_dimensions else (-1, size) # create tensor (_tensor_) and add it to the internal storage setattr(self, f"_tensor_{name}", torch.zeros(tensor_shape, device=self.device, dtype=dtype)) @@ -215,7 +224,9 @@ def add_samples(self, **tensors: torch.Tensor) -> None: :raises ValueError: No tensors were provided or the tensors have incompatible shapes """ if not tensors: - raise ValueError("No samples to be recorded in memory. Pass samples as key-value arguments (where key is the tensor name)") + raise ValueError( + "No samples to be recorded in memory. 
Pass samples as key-value arguments (where key is the tensor name)" + ) # dimensions and shapes of the tensors (assume all tensors have the dimensions of the first tensor) tmp = tensors.get("states", tensors[next(iter(tensors))]) # ask for states first @@ -231,7 +242,9 @@ def add_samples(self, **tensors: torch.Tensor) -> None: elif dim > 1 and shape[0] < self.num_envs: for name, tensor in tensors.items(): if name in self.tensors: - self.tensors[name][self.memory_index, self.env_index:self.env_index + tensor.shape[0]].copy_(tensor) + self.tensors[name][self.memory_index, self.env_index : self.env_index + tensor.shape[0]].copy_( + tensor + ) self.env_index += tensor.shape[0] # single environment - multi sample (number of environments greater than num_envs (num_envs = 1)) elif dim > 1 and self.num_envs == 1: @@ -240,7 +253,9 @@ def add_samples(self, **tensors: torch.Tensor) -> None: num_samples = min(shape[0], self.memory_size - self.memory_index) remaining_samples = shape[0] - num_samples # copy the first n samples - self.tensors[name][self.memory_index:self.memory_index + num_samples].copy_(tensor[:num_samples].unsqueeze(dim=1)) + self.tensors[name][self.memory_index : self.memory_index + num_samples].copy_( + tensor[:num_samples].unsqueeze(dim=1) + ) self.memory_index += num_samples # storage remaining samples if remaining_samples > 0: @@ -267,11 +282,9 @@ def add_samples(self, **tensors: torch.Tensor) -> None: if self.export: self.save(directory=self.export_directory, format=self.export_format) - def sample(self, - names: Tuple[str], - batch_size: int, - mini_batches: int = 1, - sequence_length: int = 1) -> List[List[torch.Tensor]]: + def sample( + self, names: Tuple[str], batch_size: int, mini_batches: int = 1, sequence_length: int = 1 + ) -> List[List[torch.Tensor]]: """Data sampling method to be implemented by the inheriting classes :param names: Tensors names from which to obtain the samples @@ -291,7 +304,9 @@ def sample(self, """ raise NotImplementedError("The sampling method (.sample()) is not implemented") - def sample_by_index(self, names: Tuple[str], indexes: Union[tuple, np.ndarray, torch.Tensor], mini_batches: int = 1) -> List[List[torch.Tensor]]: + def sample_by_index( + self, names: Tuple[str], indexes: Union[tuple, np.ndarray, torch.Tensor], mini_batches: int = 1 + ) -> List[List[torch.Tensor]]: """Sample data from memory according to their indexes :param names: Tensors names from which to obtain the samples @@ -310,7 +325,9 @@ def sample_by_index(self, names: Tuple[str], indexes: Union[tuple, np.ndarray, t return [[self.tensors_view[name][batch] for name in names] for batch in batches] return [[self.tensors_view[name][indexes] for name in names]] - def sample_all(self, names: Tuple[str], mini_batches: int = 1, sequence_length: int = 1) -> List[List[torch.Tensor]]: + def sample_all( + self, names: Tuple[str], mini_batches: int = 1, sequence_length: int = 1 + ) -> List[List[torch.Tensor]]: """Sample all data from memory :param names: Tensors names from which to obtain the samples @@ -327,7 +344,9 @@ def sample_all(self, names: Tuple[str], mini_batches: int = 1, sequence_length: # sequential order if sequence_length > 1: if mini_batches > 1: - batches = BatchSampler(self.all_sequence_indexes, batch_size=len(self.all_sequence_indexes) // mini_batches, drop_last=True) + batches = BatchSampler( + self.all_sequence_indexes, batch_size=len(self.all_sequence_indexes) // mini_batches, drop_last=True + ) return [[self.tensors_view[name][batch] for name in names] for batch in batches] 
return [[self.tensors_view[name][self.all_sequence_indexes] for name in names]] @@ -366,8 +385,11 @@ def save(self, directory: str = "", format: str = "pt") -> None: if not directory: directory = self.export_directory os.makedirs(os.path.join(directory, "memories"), exist_ok=True) - memory_path = os.path.join(directory, "memories", \ - "{}_memory_{}.{}".format(datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S-%f"), hex(id(self)), format)) + memory_path = os.path.join( + directory, + "memories", + "{}_memory_{}.{}".format(datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S-%f"), hex(id(self)), format), + ) # torch if format == "pt": @@ -386,7 +408,9 @@ def save(self, directory: str = "", format: str = "pt") -> None: writer.writerow([item for sublist in headers for item in sublist]) # write rows for i in range(len(self)): - writer.writerow(functools.reduce(operator.iconcat, [self.tensors_view[name][i].tolist() for name in names], [])) + writer.writerow( + functools.reduce(operator.iconcat, [self.tensors_view[name][i].tolist() for name in names], []) + ) # unsupported format else: raise ValueError(f"Unsupported format: {format}. Available formats: pt, csv, npz") diff --git a/skrl/memories/torch/random.py b/skrl/memories/torch/random.py index e20430b9..50700a76 100644 --- a/skrl/memories/torch/random.py +++ b/skrl/memories/torch/random.py @@ -6,14 +6,16 @@ class RandomMemory(Memory): - def __init__(self, - memory_size: int, - num_envs: int = 1, - device: Optional[Union[str, torch.device]] = None, - export: bool = False, - export_format: str = "pt", - export_directory: str = "", - replacement=True) -> None: + def __init__( + self, + memory_size: int, + num_envs: int = 1, + device: Optional[Union[str, torch.device]] = None, + export: bool = False, + export_format: str = "pt", + export_directory: str = "", + replacement=True, + ) -> None: """Random sampling memory Sample a batch from memory randomly @@ -45,11 +47,9 @@ def __init__(self, self._replacement = replacement - def sample(self, - names: Tuple[str], - batch_size: int, - mini_batches: int = 1, - sequence_length: int = 1) -> List[List[torch.Tensor]]: + def sample( + self, names: Tuple[str], batch_size: int, mini_batches: int = 1, sequence_length: int = 1 + ) -> List[List[torch.Tensor]]: """Sample a batch from memory randomly :param names: Tensors names from which to obtain the samples diff --git a/skrl/models/jax/base.py b/skrl/models/jax/base.py index 9e68954a..0d8615cd 100644 --- a/skrl/models/jax/base.py +++ b/skrl/models/jax/base.py @@ -15,11 +15,12 @@ def _vectorize_leaves(leaves: Sequence[jax.Array]) -> jax.Array: return jnp.expand_dims(jnp.concatenate(list(map(jnp.ravel, leaves)), axis=-1), 0) + @jax.jit def _unvectorize_leaves(leaves: Sequence[jax.Array], vector: jax.Array) -> Sequence[jax.Array]: offset = 0 for i, leaf in enumerate(leaves): - leaves[i] = leaves[i].at[:].set(vector.at[0, offset:offset + leaf.size].get().reshape(leaf.shape)) + leaves[i] = leaves[i].at[:].set(vector.at[0, offset : offset + leaf.size].get().reshape(leaf.shape)) offset += leaf.size return leaves @@ -38,12 +39,14 @@ class Model(flax.linen.Module): action_space: Union[int, Sequence[int], gymnasium.Space] device: Optional[Union[str, jax.Device]] = None - def __init__(self, - observation_space: Union[int, Sequence[int], gymnasium.Space], - action_space: Union[int, Sequence[int], gymnasium.Space], - device: Optional[Union[str, jax.Device]] = None, - parent: Optional[Any] = None, - name: Optional[str] = None) -> None: + def __init__( + self, + 
observation_space: Union[int, Sequence[int], gymnasium.Space], + action_space: Union[int, Sequence[int], gymnasium.Space], + device: Optional[Union[str, jax.Device]] = None, + parent: Optional[Any] = None, + name: Optional[str] = None, + ) -> None: """Base class representing a function approximator The following properties are defined: @@ -95,7 +98,7 @@ def __call__(self, inputs, role): else: self.device = device if type(device) == str: - device_type, device_index = f"{device}:0".split(':')[:2] + device_type, device_index = f"{device}:0".split(":")[:2] self.device = jax.devices(device_type)[int(device_index)] self.observation_space = observation_space @@ -110,10 +113,9 @@ def __call__(self, inputs, role): self.parent = parent self.name = name - def init_state_dict(self, - role: str, - inputs: Mapping[str, Union[np.ndarray, jax.Array]] = {}, - key: Optional[jax.Array] = None) -> None: + def init_state_dict( + self, role: str, inputs: Mapping[str, Union[np.ndarray, jax.Array]] = {}, key: Optional[jax.Array] = None + ) -> None: """Initialize state dictionary :param role: Role play by the model @@ -134,15 +136,14 @@ def init_state_dict(self, if key is None: key = config.jax.key if isinstance(inputs["states"], (int, np.int32, np.int64)): - inputs["states"] = np.array(inputs["states"]).reshape(-1,1) + inputs["states"] = np.array(inputs["states"]).reshape(-1, 1) # init internal state dict with jax.default_device(self.device): self.state_dict = StateDict.create(apply_fn=self.apply, params=self.init(key, inputs, role)) - def tensor_to_space(self, - tensor: Union[np.ndarray, jax.Array], - space: gymnasium.Space, - start: int = 0) -> Union[Union[np.ndarray, jax.Array], dict]: + def tensor_to_space( + self, tensor: Union[np.ndarray, jax.Array], space: gymnasium.Space, start: int = 0 + ) -> Union[Union[np.ndarray, jax.Array], dict]: """Map a flat tensor to a Gym/Gymnasium space .. warning:: @@ -174,10 +175,16 @@ def tensor_to_space(self, """ return unflatten_tensorized_space(space, tensor) - def random_act(self, - inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], - role: str = "", - params: Optional[jax.Array] = None) -> Tuple[Union[np.ndarray, jax.Array], Union[Union[np.ndarray, jax.Array], None], Mapping[str, Union[Union[np.ndarray, jax.Array], Any]]]: + def random_act( + self, + inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], + role: str = "", + params: Optional[jax.Array] = None, + ) -> Tuple[ + Union[np.ndarray, jax.Array], + Union[Union[np.ndarray, jax.Array], None], + Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], + ]: """Act randomly according to the action space :param inputs: Model inputs. 
The most common keys are: @@ -201,7 +208,11 @@ def random_act(self, actions = np.random.randint(self.action_space.n, size=(inputs["states"].shape[0], 1)) # continuous action space (Box) elif isinstance(self.action_space, gymnasium.spaces.Box): - actions = np.random.uniform(low=self.action_space.low[0], high=self.action_space.high[0], size=(inputs["states"].shape[0], self.num_actions)) + actions = np.random.uniform( + low=self.action_space.low[0], + high=self.action_space.high[0], + size=(inputs["states"].shape[0], self.num_actions), + ) else: raise NotImplementedError(f"Action space type ({type(self.action_space)}) not supported") @@ -262,8 +273,10 @@ def init_weights(self, method_name: str = "normal", *args, **kwargs) -> None: method = eval(f"flax.linen.initializers.{method_name}") else: method = eval(f"flax.linen.initializers.{method_name}(*args, **kwargs)") - params = jax.tree_util.tree_map_with_path(lambda path, param: method(config.jax.key, param.shape) if path[-1].key == "kernel" else param, - self.state_dict.params) + params = jax.tree_util.tree_map_with_path( + lambda path, param: method(config.jax.key, param.shape) if path[-1].key == "kernel" else param, + self.state_dict.params, + ) self.state_dict = self.state_dict.replace(params=params) def init_biases(self, method_name: str = "constant_", *args, **kwargs) -> None: @@ -291,8 +304,10 @@ def init_biases(self, method_name: str = "constant_", *args, **kwargs) -> None: method = eval(f"flax.linen.initializers.{method_name}") else: method = eval(f"flax.linen.initializers.{method_name}(*args, **kwargs)") - params = jax.tree_util.tree_map_with_path(lambda path, param: method(config.jax.key, param.shape) if path[-1].key == "bias" else param, - self.state_dict.params) + params = jax.tree_util.tree_map_with_path( + lambda path, param: method(config.jax.key, param.shape) if path[-1].key == "bias" else param, + self.state_dict.params, + ) self.state_dict = self.state_dict.replace(params=params) def get_specification(self) -> Mapping[str, Any]: @@ -319,10 +334,12 @@ def get_specification(self) -> Mapping[str, Any]: """ return {} - def act(self, - inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], - role: str = "", - params: Optional[jax.Array] = None) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: + def act( + self, + inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], + role: str = "", + params: Optional[jax.Array] = None, + ) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: """Act according to the specified behavior (to be implemented by the inheriting classes) Agents will call this method to obtain the decision to be taken given the state of the environment. @@ -401,12 +418,14 @@ def load(self, path: str) -> None: self.state_dict = self.state_dict.replace(params=params) self.set_mode("eval") - def migrate(self, - state_dict: Optional[Mapping[str, Any]] = None, - path: Optional[str] = None, - name_map: Mapping[str, str] = {}, - auto_mapping: bool = True, - verbose: bool = False) -> bool: + def migrate( + self, + state_dict: Optional[Mapping[str, Any]] = None, + path: Optional[str] = None, + name_map: Mapping[str, str] = {}, + auto_mapping: bool = True, + verbose: bool = False, + ) -> bool: """Migrate the specified external model's state dict to the current model .. 
warning:: @@ -463,8 +482,11 @@ def update_parameters(self, model: flax.linen.Module, polyak: float = 1) -> None # soft update else: # HACK: Does it make sense to use https://optax.readthedocs.io/en/latest/api.html?#optax.incremental_update - params = jax.tree_util.tree_map(lambda params, model_params: polyak * model_params + (1 - polyak) * params, - self.state_dict.params, model.state_dict.params) + params = jax.tree_util.tree_map( + lambda params, model_params: polyak * model_params + (1 - polyak) * params, + self.state_dict.params, + model.state_dict.params, + ) self.state_dict = self.state_dict.replace(params=params) def broadcast_parameters(self, rank: int = 0): @@ -512,5 +534,7 @@ def reduce_parameters(self, tree: Any) -> Any: # return unflatten(jnp.squeeze(vector, 0)) leaves, treedef = jax.tree.flatten(tree) - vector = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(_vectorize_leaves(leaves)) / config.jax.world_size + vector = ( + jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(_vectorize_leaves(leaves)) / config.jax.world_size + ) return jax.tree.unflatten(treedef, _unvectorize_leaves(leaves, vector)) diff --git a/skrl/models/jax/categorical.py b/skrl/models/jax/categorical.py index 14ad316a..8e8e43a1 100644 --- a/skrl/models/jax/categorical.py +++ b/skrl/models/jax/categorical.py @@ -12,10 +12,7 @@ # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @partial(jax.jit, static_argnames=("unnormalized_log_prob")) -def _categorical(net_output, - unnormalized_log_prob, - taken_actions, - key): +def _categorical(net_output, unnormalized_log_prob, taken_actions, key): # normalize if unnormalized_log_prob: logits = net_output - jax.scipy.special.logsumexp(net_output, axis=-1, keepdims=True) @@ -34,6 +31,7 @@ def _categorical(net_output, return actions.reshape(-1, 1), log_prob.reshape(-1, 1) + @jax.jit def _entropy(logits): logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True) @@ -92,10 +90,12 @@ def __init__(self, unnormalized_log_prob: bool = True, role: str = "") -> None: # https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.IncorrectPostInitOverrideError flax.linen.Module.__post_init__(self) - def act(self, - inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], - role: str = "", - params: Optional[jax.Array] = None) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: + def act( + self, + inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], + role: str = "", + params: Optional[jax.Array] = None, + ) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: """Act stochastically in response to the state of the environment :param inputs: Model inputs. 
The most common keys are: @@ -130,10 +130,9 @@ def act(self, # map from states/observations to normalized probabilities or unnormalized log probabilities net_output, outputs = self.apply(self.state_dict.params if params is None else params, inputs, role) - actions, log_prob = _categorical(net_output, - self._unnormalized_log_prob, - inputs.get("taken_actions", None), - subkey) + actions, log_prob = _categorical( + net_output, self._unnormalized_log_prob, inputs.get("taken_actions", None), subkey + ) outputs["net_output"] = net_output # avoid jax.errors.UnexpectedTracerError diff --git a/skrl/models/jax/deterministic.py b/skrl/models/jax/deterministic.py index 407251de..64fb72e1 100644 --- a/skrl/models/jax/deterministic.py +++ b/skrl/models/jax/deterministic.py @@ -58,10 +58,12 @@ def __init__(self, clip_actions: bool = False, role: str = "") -> None: # https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.IncorrectPostInitOverrideError flax.linen.Module.__post_init__(self) - def act(self, - inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], - role: str = "", - params: Optional[jax.Array] = None) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: + def act( + self, + inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], + role: str = "", + params: Optional[jax.Array] = None, + ) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: """Act deterministically in response to the state of the environment :param inputs: Model inputs. The most common keys are: diff --git a/skrl/models/jax/gaussian.py b/skrl/models/jax/gaussian.py index e9783be6..ab80a7b4 100644 --- a/skrl/models/jax/gaussian.py +++ b/skrl/models/jax/gaussian.py @@ -13,15 +13,9 @@ # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @partial(jax.jit, static_argnames=("reduction")) -def _gaussian(loc, - log_std, - log_std_min, - log_std_max, - clip_actions_min, - clip_actions_max, - taken_actions, - key, - reduction): +def _gaussian( + loc, log_std, log_std_min, log_std_max, clip_actions_min, clip_actions_max, taken_actions, key, reduction +): # clamp log standard deviations log_std = jnp.clip(log_std, a_min=log_std_min, a_max=log_std_max) @@ -45,19 +39,22 @@ def _gaussian(loc, return actions, log_prob, log_std, scale + @jax.jit def _entropy(scale): return 0.5 + 0.5 * jnp.log(2 * jnp.pi) + jnp.log(scale) class GaussianMixin: - def __init__(self, - clip_actions: bool = False, - clip_log_std: bool = True, - min_log_std: float = -20, - max_log_std: float = 2, - reduction: str = "sum", - role: str = "") -> None: + def __init__( + self, + clip_actions: bool = False, + clip_log_std: bool = True, + min_log_std: float = -20, + max_log_std: float = 2, + reduction: str = "sum", + role: str = "", + ) -> None: """Gaussian mixin model (stochastic model) :param clip_actions: Flag to indicate whether the actions should be clipped to the action space (default: ``False``) @@ -132,8 +129,11 @@ def __init__(self, if reduction not in ["mean", "sum", "prod", "none"]: raise ValueError("reduction must be one of 'mean', 'sum', 'prod' or 'none'") - self._reduction = jnp.mean if reduction == "mean" else jnp.sum if reduction == "sum" \ - else jnp.prod if reduction == "prod" else None + self._reduction = ( + jnp.mean + if reduction == "mean" + else jnp.sum if reduction == "sum" else jnp.prod if reduction == "prod" else None + ) self._i = 0 self._key = config.jax.key @@ -141,10 +141,12 @@ def __init__(self, # 
https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.IncorrectPostInitOverrideError flax.linen.Module.__post_init__(self) - def act(self, - inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], - role: str = "", - params: Optional[jax.Array] = None) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: + def act( + self, + inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], + role: str = "", + params: Optional[jax.Array] = None, + ) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: """Act stochastically in response to the state of the environment :param inputs: Model inputs. The most common keys are: @@ -179,15 +181,17 @@ def act(self, # map from states/observations to mean actions and log standard deviations mean_actions, log_std, outputs = self.apply(self.state_dict.params if params is None else params, inputs, role) - actions, log_prob, log_std, stddev = _gaussian(mean_actions, - log_std, - self._log_std_min, - self._log_std_max, - self.clip_actions_min, - self.clip_actions_max, - inputs.get("taken_actions", None), - subkey, - self._reduction) + actions, log_prob, log_std, stddev = _gaussian( + mean_actions, + log_std, + self._log_std_min, + self._log_std_max, + self.clip_actions_min, + self.clip_actions_max, + inputs.get("taken_actions", None), + subkey, + self._reduction, + ) outputs["mean_actions"] = mean_actions # avoid jax.errors.UnexpectedTracerError diff --git a/skrl/models/jax/multicategorical.py b/skrl/models/jax/multicategorical.py index adb2a95c..cb818513 100644 --- a/skrl/models/jax/multicategorical.py +++ b/skrl/models/jax/multicategorical.py @@ -12,10 +12,7 @@ # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @partial(jax.jit, static_argnames=("unnormalized_log_prob")) -def _categorical(net_output, - unnormalized_log_prob, - taken_actions, - key): +def _categorical(net_output, unnormalized_log_prob, taken_actions, key): # normalize if unnormalized_log_prob: logits = net_output - jax.scipy.special.logsumexp(net_output, axis=-1, keepdims=True) @@ -34,6 +31,7 @@ def _categorical(net_output, return actions.reshape(-1, 1), log_prob.reshape(-1, 1) + @jax.jit def _entropy(logits): logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True) @@ -94,8 +92,11 @@ def __init__(self, unnormalized_log_prob: bool = True, reduction: str = "sum", r if reduction not in ["mean", "sum", "prod", "none"]: raise ValueError("reduction must be one of 'mean', 'sum', 'prod' or 'none'") - self._reduction = jnp.mean if reduction == "mean" else jnp.sum if reduction == "sum" \ - else jnp.prod if reduction == "prod" else None + self._reduction = ( + jnp.mean + if reduction == "mean" + else jnp.sum if reduction == "sum" else jnp.prod if reduction == "prod" else None + ) self._i = 0 self._key = config.jax.key @@ -106,10 +107,12 @@ def __init__(self, unnormalized_log_prob: bool = True, reduction: str = "sum", r # https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.IncorrectPostInitOverrideError flax.linen.Module.__post_init__(self) - def act(self, - inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], - role: str = "", - params: Optional[jax.Array] = None) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: + def act( + self, + inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], + role: str = "", + params: Optional[jax.Array] = None, + ) -> Tuple[jax.Array, 
Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: """Act stochastically in response to the state of the environment :param inputs: Model inputs. The most common keys are: @@ -154,10 +157,7 @@ def act(self, # compute actions and log_prob actions, log_prob = [], [] for _net_output, _taken_actions in zip(net_outputs, taken_actions): - _actions, _log_prob = _categorical(_net_output, - self._unnormalized_log_prob, - _taken_actions, - subkey) + _actions, _log_prob = _categorical(_net_output, self._unnormalized_log_prob, _taken_actions, subkey) actions.append(_actions) log_prob.append(_log_prob) diff --git a/skrl/models/torch/base.py b/skrl/models/torch/base.py index 39e746b4..a62f2522 100644 --- a/skrl/models/torch/base.py +++ b/skrl/models/torch/base.py @@ -11,10 +11,12 @@ class Model(torch.nn.Module): - def __init__(self, - observation_space: Union[int, Sequence[int], gymnasium.Space], - action_space: Union[int, Sequence[int], gymnasium.Space], - device: Optional[Union[str, torch.device]] = None) -> None: + def __init__( + self, + observation_space: Union[int, Sequence[int], gymnasium.Space], + action_space: Union[int, Sequence[int], gymnasium.Space], + device: Optional[Union[str, torch.device]] = None, + ) -> None: """Base class representing a function approximator The following properties are defined: @@ -54,7 +56,9 @@ def act(self, inputs, role=""): """ super(Model, self).__init__() - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device) + self.device = ( + torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device) + ) self.observation_space = observation_space self.action_space = action_space @@ -63,10 +67,9 @@ def act(self, inputs, role=""): self._random_distribution = None - def tensor_to_space(self, - tensor: torch.Tensor, - space: gymnasium.Space, - start: int = 0) -> Union[torch.Tensor, dict]: + def tensor_to_space( + self, tensor: torch.Tensor, space: gymnasium.Space, start: int = 0 + ) -> Union[torch.Tensor, dict]: """Map a flat tensor to a Gym/Gymnasium space .. warning:: @@ -98,9 +101,9 @@ def tensor_to_space(self, """ return unflatten_tensorized_space(space, tensor) - def random_act(self, - inputs: Mapping[str, Union[torch.Tensor, Any]], - role: str = "") -> Tuple[torch.Tensor, None, Mapping[str, Union[torch.Tensor, Any]]]: + def random_act( + self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = "" + ) -> Tuple[torch.Tensor, None, Mapping[str, Union[torch.Tensor, Any]]]: """Act randomly according to the action space :param inputs: Model inputs. 
The most common keys are: @@ -124,9 +127,14 @@ def random_act(self, if self._random_distribution is None: self._random_distribution = torch.distributions.uniform.Uniform( low=torch.tensor(self.action_space.low[0], device=self.device, dtype=torch.float32), - high=torch.tensor(self.action_space.high[0], device=self.device, dtype=torch.float32)) - - return self._random_distribution.sample(sample_shape=(inputs["states"].shape[0], self.num_actions)), None, {} + high=torch.tensor(self.action_space.high[0], device=self.device, dtype=torch.float32), + ) + + return ( + self._random_distribution.sample(sample_shape=(inputs["states"].shape[0], self.num_actions)), + None, + {}, + ) else: raise NotImplementedError(f"Action space type ({type(self.action_space)}) not supported") @@ -178,6 +186,7 @@ def init_weights(self, method_name: str = "orthogonal_", *args, **kwargs) -> Non # initialize all weights with normal distribution with mean 0 and standard deviation 0.25 >>> model.init_weights(method_name="normal_", mean=0.0, std=0.25) """ + def _update_weights(module, method_name, args, kwargs): for layer in module: if isinstance(layer, torch.nn.Sequential): @@ -211,6 +220,7 @@ def init_biases(self, method_name: str = "constant_", *args, **kwargs) -> None: # initialize all biases with normal distribution with mean 0 and standard deviation 0.25 >>> model.init_biases(method_name="normal_", mean=0.0, std=0.25) """ + def _update_biases(module, method_name, args, kwargs): for layer in module: if isinstance(layer, torch.nn.Sequential): @@ -244,9 +254,9 @@ def get_specification(self) -> Mapping[str, Any]: """ return {} - def forward(self, - inputs: Mapping[str, Union[torch.Tensor, Any]], - role: str = "") -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: + def forward( + self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = "" + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: """Forward pass of the model This method calls the ``.act()`` method and returns its outputs @@ -266,9 +276,9 @@ def forward(self, """ return self.act(inputs, role) - def compute(self, - inputs: Mapping[str, Union[torch.Tensor, Any]], - role: str = "") -> Tuple[Union[torch.Tensor, Mapping[str, Union[torch.Tensor, Any]]]]: + def compute( + self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = "" + ) -> Tuple[Union[torch.Tensor, Mapping[str, Union[torch.Tensor, Any]]]]: """Define the computation performed (to be implemented by the inheriting classes) by the models :param inputs: Model inputs. The most common keys are: @@ -286,9 +296,9 @@ def compute(self, """ raise NotImplementedError("The computation performed by the models (.compute()) is not implemented") - def act(self, - inputs: Mapping[str, Union[torch.Tensor, Any]], - role: str = "") -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: + def act( + self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = "" + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: """Act according to the specified behavior (to be implemented by the inheriting classes) Agents will call this method to obtain the decision to be taken given the state of the environment. 
@@ -375,12 +385,14 @@ def load(self, path: str) -> None: self.load_state_dict(state_dict) self.eval() - def migrate(self, - state_dict: Optional[Mapping[str, torch.Tensor]] = None, - path: Optional[str] = None, - name_map: Mapping[str, str] = {}, - auto_mapping: bool = True, - verbose: bool = False) -> bool: + def migrate( + self, + state_dict: Optional[Mapping[str, torch.Tensor]] = None, + path: Optional[str] = None, + name_map: Mapping[str, str] = {}, + auto_mapping: bool = True, + verbose: bool = False, + ) -> bool: """Migrate the specified extrernal model's state dict to the current model The final storage device is determined by the constructor of the model @@ -491,9 +503,10 @@ def migrate(self, # stable-baselines3 elif path.endswith(".zip"): import zipfile + try: - archive = zipfile.ZipFile(path, 'r') - with archive.open('policy.pth', mode="r") as file: + archive = zipfile.ZipFile(path, "r") + with archive.open("policy.pth", mode="r") as file: state_dict = torch.load(file, map_location=self.device) except KeyError as e: logger.warning(str(e)) @@ -528,7 +541,9 @@ def migrate(self, logger.info(f" |-- map: {name} <- {external_name}") break else: - logger.warning(f"Shape mismatch for {name} <- {external_name} : {tensor.shape} != {external_tensor.shape}") + logger.warning( + f"Shape mismatch for {name} <- {external_name} : {tensor.shape} != {external_tensor.shape}" + ) # auto-mapped names if auto_mapping and name not in name_map: if tensor.shape == external_tensor.shape: @@ -669,6 +684,8 @@ def reduce_parameters(self): offset = 0 for parameters in self.parameters(): if parameters.grad is not None: - parameters.grad.data.copy_(gradients[offset:offset + parameters.numel()] \ - .view_as(parameters.grad.data) / config.torch.world_size) + parameters.grad.data.copy_( + gradients[offset : offset + parameters.numel()].view_as(parameters.grad.data) + / config.torch.world_size + ) offset += parameters.numel() diff --git a/skrl/models/torch/categorical.py b/skrl/models/torch/categorical.py index 7f52f1b3..65bdd35b 100644 --- a/skrl/models/torch/categorical.py +++ b/skrl/models/torch/categorical.py @@ -55,9 +55,9 @@ def __init__(self, unnormalized_log_prob: bool = True, role: str = "") -> None: self._unnormalized_log_prob = unnormalized_log_prob self._distribution = None - def act(self, - inputs: Mapping[str, Union[torch.Tensor, Any]], - role: str = "") -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: + def act( + self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = "" + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: """Act stochastically in response to the state of the environment :param inputs: Model inputs. 
The most common keys are: diff --git a/skrl/models/torch/deterministic.py b/skrl/models/torch/deterministic.py index 8dd52d45..580511b1 100644 --- a/skrl/models/torch/deterministic.py +++ b/skrl/models/torch/deterministic.py @@ -56,9 +56,9 @@ def __init__(self, clip_actions: bool = False, role: str = "") -> None: self._clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32) self._clip_actions_max = torch.tensor(self.action_space.high, device=self.device, dtype=torch.float32) - def act(self, - inputs: Mapping[str, Union[torch.Tensor, Any]], - role: str = "") -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: + def act( + self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = "" + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: """Act deterministically in response to the state of the environment :param inputs: Model inputs. The most common keys are: diff --git a/skrl/models/torch/gaussian.py b/skrl/models/torch/gaussian.py index 6b569cca..23cea719 100644 --- a/skrl/models/torch/gaussian.py +++ b/skrl/models/torch/gaussian.py @@ -7,13 +7,15 @@ class GaussianMixin: - def __init__(self, - clip_actions: bool = False, - clip_log_std: bool = True, - min_log_std: float = -20, - max_log_std: float = 2, - reduction: str = "sum", - role: str = "") -> None: + def __init__( + self, + clip_actions: bool = False, + clip_log_std: bool = True, + min_log_std: float = -20, + max_log_std: float = 2, + reduction: str = "sum", + role: str = "", + ) -> None: """Gaussian mixin model (stochastic model) :param clip_actions: Flag to indicate whether the actions should be clipped to the action space (default: ``False``) @@ -87,12 +89,15 @@ def __init__(self, if reduction not in ["mean", "sum", "prod", "none"]: raise ValueError("reduction must be one of 'mean', 'sum', 'prod' or 'none'") - self._reduction = torch.mean if reduction == "mean" else torch.sum if reduction == "sum" \ - else torch.prod if reduction == "prod" else None - - def act(self, - inputs: Mapping[str, Union[torch.Tensor, Any]], - role: str = "") -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: + self._reduction = ( + torch.mean + if reduction == "mean" + else torch.sum if reduction == "sum" else torch.prod if reduction == "prod" else None + ) + + def act( + self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = "" + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: """Act stochastically in response to the state of the environment :param inputs: Model inputs. 
The most common keys are: diff --git a/skrl/models/torch/multicategorical.py b/skrl/models/torch/multicategorical.py index 3ed95e8b..0e60d091 100644 --- a/skrl/models/torch/multicategorical.py +++ b/skrl/models/torch/multicategorical.py @@ -63,12 +63,15 @@ def __init__(self, unnormalized_log_prob: bool = True, reduction: str = "sum", r if reduction not in ["mean", "sum", "prod", "none"]: raise ValueError("reduction must be one of 'mean', 'sum', 'prod' or 'none'") - self._reduction = torch.mean if reduction == "mean" else torch.sum if reduction == "sum" \ - else torch.prod if reduction == "prod" else None - - def act(self, - inputs: Mapping[str, Union[torch.Tensor, Any]], - role: str = "") -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: + self._reduction = ( + torch.mean + if reduction == "mean" + else torch.sum if reduction == "sum" else torch.prod if reduction == "prod" else None + ) + + def act( + self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = "" + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: """Act stochastically in response to the state of the environment :param inputs: Model inputs. The most common keys are: @@ -97,17 +100,29 @@ def act(self, # unnormalized log probabilities if self._unnormalized_log_prob: - self._distributions = [Categorical(logits=logits) for logits in torch.split(net_output, self.action_space.nvec.tolist(), dim=-1)] + self._distributions = [ + Categorical(logits=logits) + for logits in torch.split(net_output, self.action_space.nvec.tolist(), dim=-1) + ] # normalized probabilities else: - self._distributions = [Categorical(probs=probs) for probs in torch.split(net_output, self.action_space.nvec.tolist(), dim=-1)] + self._distributions = [ + Categorical(probs=probs) for probs in torch.split(net_output, self.action_space.nvec.tolist(), dim=-1) + ] # actions actions = torch.stack([distribution.sample() for distribution in self._distributions], dim=-1) # log of the probability density function - log_prob = torch.stack([distribution.log_prob(_actions.view(-1)) for _actions, distribution \ - in zip(torch.unbind(inputs.get("taken_actions", actions), dim=-1), self._distributions)], dim=-1) + log_prob = torch.stack( + [ + distribution.log_prob(_actions.view(-1)) + for _actions, distribution in zip( + torch.unbind(inputs.get("taken_actions", actions), dim=-1), self._distributions + ) + ], + dim=-1, + ) if self._reduction is not None: log_prob = self._reduction(log_prob, dim=-1) if log_prob.dim() != actions.dim(): @@ -131,7 +146,9 @@ def get_entropy(self, role: str = "") -> torch.Tensor: torch.Size([4096, 1]) """ if self._distributions: - entropy = torch.stack([distribution.entropy().to(self.device) for distribution in self._distributions], dim=-1) + entropy = torch.stack( + [distribution.entropy().to(self.device) for distribution in self._distributions], dim=-1 + ) if self._reduction is not None: return self._reduction(entropy, dim=-1).unsqueeze(-1) return entropy diff --git a/skrl/models/torch/multivariate_gaussian.py b/skrl/models/torch/multivariate_gaussian.py index 9a66041c..1ae7da16 100644 --- a/skrl/models/torch/multivariate_gaussian.py +++ b/skrl/models/torch/multivariate_gaussian.py @@ -7,12 +7,14 @@ class MultivariateGaussianMixin: - def __init__(self, - clip_actions: bool = False, - clip_log_std: bool = True, - min_log_std: float = -20, - max_log_std: float = 2, - role: str = "") -> None: + def __init__( + self, + clip_actions: bool = False, + clip_log_std: 
bool = True, + min_log_std: float = -20, + max_log_std: float = 2, + role: str = "", + ) -> None: """Multivariate Gaussian mixin model (stochastic model) :param clip_actions: Flag to indicate whether the actions should be clipped to the action space (default: ``False``) @@ -78,9 +80,9 @@ def __init__(self, self._num_samples = None self._distribution = None - def act(self, - inputs: Mapping[str, Union[torch.Tensor, Any]], - role: str = "") -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: + def act( + self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = "" + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: """Act stochastically in response to the state of the environment :param inputs: Model inputs. The most common keys are: diff --git a/skrl/models/torch/tabular.py b/skrl/models/torch/tabular.py index 58afa805..a22e455f 100644 --- a/skrl/models/torch/tabular.py +++ b/skrl/models/torch/tabular.py @@ -47,17 +47,16 @@ def __init__(self, num_envs: int = 1, role: str = "") -> None: self.num_envs = num_envs def __repr__(self) -> str: - """String representation of an object as torch.nn.Module - """ + """String representation of an object as torch.nn.Module""" lines = [] for name in self._get_tensor_names(): tensor = getattr(self, name) lines.append(f"({name}): {tensor.__class__.__name__}(shape={list(tensor.shape)})") - main_str = self.__class__.__name__ + '(' + main_str = self.__class__.__name__ + "(" if lines: main_str += "\n {}\n".format("\n ".join(lines)) - main_str += ')' + main_str += ")" return main_str def _get_tensor_names(self) -> Sequence[str]: @@ -72,9 +71,9 @@ def _get_tensor_names(self) -> Sequence[str]: tensors.append(attr) return sorted(tensors) - def act(self, - inputs: Mapping[str, Union[torch.Tensor, Any]], - role: str = "") -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: + def act( + self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = "" + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: """Act in response to the state of the environment :param inputs: Model inputs. 
The most common keys are: @@ -157,7 +156,9 @@ def load_state_dict(self, state_dict: Mapping, strict: bool = True) -> None: if _tensor.shape == tensor.shape and _tensor.dtype == tensor.dtype: setattr(self, name, tensor) else: - raise ValueError(f"Tensor shape ({_tensor.shape} vs {tensor.shape}) or dtype ({_tensor.dtype} vs {tensor.dtype}) mismatch") + raise ValueError( + f"Tensor shape ({_tensor.shape} vs {tensor.shape}) or dtype ({_tensor.dtype} vs {tensor.dtype}) mismatch" + ) else: raise ValueError(f"{name} is not a tensor of {self.__class__.__name__}") @@ -209,6 +210,8 @@ def load(self, path: str) -> None: if _tensor.shape == tensor.shape and _tensor.dtype == tensor.dtype: setattr(self, name, tensor) else: - raise ValueError(f"Tensor shape ({_tensor.shape} vs {tensor.shape}) or dtype ({_tensor.dtype} vs {tensor.dtype}) mismatch") + raise ValueError( + f"Tensor shape ({_tensor.shape} vs {tensor.shape}) or dtype ({_tensor.dtype} vs {tensor.dtype}) mismatch" + ) else: raise ValueError(f"{name} is not a tensor of {self.__class__.__name__}") diff --git a/skrl/multi_agents/jax/base.py b/skrl/multi_agents/jax/base.py index 99f9e055..1c88f212 100644 --- a/skrl/multi_agents/jax/base.py +++ b/skrl/multi_agents/jax/base.py @@ -17,14 +17,16 @@ class MultiAgent: - def __init__(self, - possible_agents: Sequence[str], - models: Mapping[str, Mapping[str, Model]], - memories: Optional[Mapping[str, Memory]] = None, - observation_spaces: Optional[Mapping[str, Union[int, Sequence[int], gymnasium.Space]]] = None, - action_spaces: Optional[Mapping[str, Union[int, Sequence[int], gymnasium.Space]]] = None, - device: Optional[Union[str, jax.Device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + possible_agents: Sequence[str], + models: Mapping[str, Mapping[str, Model]], + memories: Optional[Mapping[str, Memory]] = None, + observation_spaces: Optional[Mapping[str, Union[int, Sequence[int], gymnasium.Space]]] = None, + action_spaces: Optional[Mapping[str, Union[int, Sequence[int], gymnasium.Space]]] = None, + device: Optional[Union[str, jax.Device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Base class that represent a RL multi-agent :param possible_agents: Name of all possible agents the environment could generate @@ -61,7 +63,7 @@ def __init__(self, else: self.device = device if type(device) == str: - device_type, device_index = f"{device}:0".split(':')[:2] + device_type, device_index = f"{device}:0".split(":")[:2] self.device = jax.devices(device_type)[int(device_index)] # convert the models to their respective device @@ -84,7 +86,7 @@ def __init__(self, self.checkpoint_modules = {uid: {} for uid in self.possible_agents} self.checkpoint_interval = self.cfg.get("experiment", {}).get("checkpoint_interval", 1000) self.checkpoint_store_separately = self.cfg.get("experiment", {}).get("store_separately", False) - self.checkpoint_best_modules = {"timestep": 0, "reward": -2 ** 31, "saved": True, "modules": {}} + self.checkpoint_best_modules = {"timestep": 0, "reward": -(2**31), "saved": True, "modules": {}} # experiment directory directory = self.cfg.get("experiment", {}).get("directory", "") @@ -178,10 +180,14 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: if self.cfg.get("experiment", {}).get("wandb", False): # save experiment configuration try: - models_cfg = {uid: {k: v.net._modules for (k, v) in self.models[uid].items()} for uid in self.possible_agents} + models_cfg = { + uid: {k: v.net._modules for (k, v) in self.models[uid].items()} for uid in 
self.possible_agents + } except AttributeError: - models_cfg = {uid: {k: v._modules for (k, v) in self.models[uid].items()} for uid in self.possible_agents} - wandb_config={**self.cfg, **trainer_cfg, **models_cfg} + models_cfg = { + uid: {k: v._modules for (k, v) in self.models[uid].items()} for uid in self.possible_agents + } + wandb_config = {**self.cfg, **trainer_cfg, **models_cfg} # set default values wandb_kwargs = copy.deepcopy(self.cfg.get("experiment", {}).get("wandb_kwargs", {})) wandb_kwargs.setdefault("name", os.path.split(self.experiment_dir)[-1]) @@ -190,6 +196,7 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: wandb_kwargs["config"].update(wandb_config) # init Weights & Biases import wandb + wandb.init(**wandb_kwargs) # main entry to log data for consumption and visualization by TensorBoard @@ -200,6 +207,7 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: # tensorboard via torch SummaryWriter try: from torch.utils.tensorboard import SummaryWriter + self.writer = SummaryWriter(log_dir=self.experiment_dir) except ImportError as e: pass @@ -223,6 +231,7 @@ def add_scalar(self, tag, value, step): if self.writer is None: try: import tensorboardX + self.writer = tensorboardX.SummaryWriter(log_dir=self.experiment_dir) except ImportError as e: pass @@ -290,12 +299,19 @@ def write_checkpoint(self, timestep: int, timesteps: int) -> None: if self.checkpoint_store_separately: for uid in self.possible_agents: for name, module in self.checkpoint_modules[uid].items(): - with open(os.path.join(self.experiment_dir, "checkpoints", f"{uid}_{name}_{tag}.pickle"), "wb") as file: + with open( + os.path.join(self.experiment_dir, "checkpoints", f"{uid}_{name}_{tag}.pickle"), "wb" + ) as file: pickle.dump(flax.serialization.to_bytes(self._get_internal_value(module)), file, protocol=4) # whole agent else: - modules = {uid: {name: flax.serialization.to_bytes(self._get_internal_value(module)) for name, module in self.checkpoint_modules[uid].items()} \ - for uid in self.possible_agents} + modules = { + uid: { + name: flax.serialization.to_bytes(self._get_internal_value(module)) + for name, module in self.checkpoint_modules[uid].items() + } + for uid in self.possible_agents + } with open(os.path.join(self.experiment_dir, "checkpoints", f"agent_{tag}.pickle"), "wb") as file: pickle.dump(modules, file, protocol=4) @@ -306,17 +322,30 @@ def write_checkpoint(self, timestep: int, timesteps: int) -> None: if self.checkpoint_store_separately: for uid in self.possible_agents: for name, module in self.checkpoint_modules.items(): - with open(os.path.join(self.experiment_dir, "checkpoints", f"best_{uid}_{name}.pickle"), "wb") as file: - pickle.dump(flax.serialization.to_bytes(self.checkpoint_best_modules["modules"][uid][name]), file, protocol=4) + with open( + os.path.join(self.experiment_dir, "checkpoints", f"best_{uid}_{name}.pickle"), "wb" + ) as file: + pickle.dump( + flax.serialization.to_bytes(self.checkpoint_best_modules["modules"][uid][name]), + file, + protocol=4, + ) # whole agent else: - modules = {uid: {name: flax.serialization.to_bytes(self.checkpoint_best_modules["modules"][uid][name]) \ - for name in self.checkpoint_modules[uid].keys()} for uid in self.possible_agents} + modules = { + uid: { + name: flax.serialization.to_bytes(self.checkpoint_best_modules["modules"][uid][name]) + for name in self.checkpoint_modules[uid].keys() + } + for uid in self.possible_agents + } with open(os.path.join(self.experiment_dir, "checkpoints", "best_agent.pickle"), 
"wb") as file: pickle.dump(modules, file, protocol=4) self.checkpoint_best_modules["saved"] = True - def act(self, states: Mapping[str, Union[np.ndarray, jax.Array]], timestep: int, timesteps: int) -> Union[np.ndarray, jax.Array]: + def act( + self, states: Mapping[str, Union[np.ndarray, jax.Array]], timestep: int, timesteps: int + ) -> Union[np.ndarray, jax.Array]: """Process the environment's states to make a decision (actions) using the main policy :param states: Environment's states @@ -333,16 +362,18 @@ def act(self, states: Mapping[str, Union[np.ndarray, jax.Array]], timestep: int, """ raise NotImplementedError - def record_transition(self, - states: Mapping[str, Union[np.ndarray, jax.Array]], - actions: Mapping[str, Union[np.ndarray, jax.Array]], - rewards: Mapping[str, Union[np.ndarray, jax.Array]], - next_states: Mapping[str, Union[np.ndarray, jax.Array]], - terminated: Mapping[str, Union[np.ndarray, jax.Array]], - truncated: Mapping[str, Union[np.ndarray, jax.Array]], - infos: Mapping[str, Any], - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Mapping[str, Union[np.ndarray, jax.Array]], + actions: Mapping[str, Union[np.ndarray, jax.Array]], + rewards: Mapping[str, Union[np.ndarray, jax.Array]], + next_states: Mapping[str, Union[np.ndarray, jax.Array]], + terminated: Mapping[str, Union[np.ndarray, jax.Array]], + truncated: Mapping[str, Union[np.ndarray, jax.Array]], + infos: Mapping[str, Any], + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory (to be implemented by the inheriting classes) Inheriting classes must call this method to record episode information (rewards, timesteps, etc.). @@ -443,8 +474,13 @@ def save(self, path: str) -> None: :param path: Path to save the model to :type path: str """ - modules = {uid: {name: flax.serialization.to_bytes(self._get_internal_value(module)) \ - for name, module in self.checkpoint_modules[uid].items()} for uid in self.possible_agents} + modules = { + uid: { + name: flax.serialization.to_bytes(self._get_internal_value(module)) + for name, module in self.checkpoint_modules[uid].items() + } + for uid in self.possible_agents + } # HACK: Does it make sense to use https://github.com/google/orbax # file.write(flax.serialization.to_bytes(modules)) @@ -475,11 +511,13 @@ def load(self, path: str) -> None: else: logger.warning(f"Cannot load the {uid}:{name} module. 
The agent doesn't have such an instance") - def migrate(self, - path: str, - name_map: Mapping[str, Mapping[str, str]] = {}, - auto_mapping: bool = True, - verbose: bool = False) -> bool: + def migrate( + self, + path: str, + name_map: Mapping[str, Mapping[str, str]] = {}, + auto_mapping: bool = True, + verbose: bool = False, + ) -> bool: """Migrate the specified extrernal checkpoint to the current agent :raises NotImplementedError: Not yet implemented @@ -509,13 +547,17 @@ def post_interaction(self, timestep: int, timesteps: int) -> None: # update best models and write checkpoints if timestep > 1 and self.checkpoint_interval > 0 and not timestep % self.checkpoint_interval: # update best models - reward = np.mean(self.tracking_data.get("Reward / Total reward (mean)", -2 ** 31)) + reward = np.mean(self.tracking_data.get("Reward / Total reward (mean)", -(2**31))) if reward > self.checkpoint_best_modules["reward"]: self.checkpoint_best_modules["timestep"] = timestep self.checkpoint_best_modules["reward"] = reward self.checkpoint_best_modules["saved"] = False - self.checkpoint_best_modules["modules"] = {uid: {k: copy.deepcopy(self._get_internal_value(v)) \ - for k, v in self.checkpoint_modules[uid].items()} for uid in self.possible_agents} + self.checkpoint_best_modules["modules"] = { + uid: { + k: copy.deepcopy(self._get_internal_value(v)) for k, v in self.checkpoint_modules[uid].items() + } + for uid in self.possible_agents + } # write checkpoints self.write_checkpoint(timestep, timesteps) diff --git a/skrl/multi_agents/jax/ippo/ippo.py b/skrl/multi_agents/jax/ippo/ippo.py index 992d18e6..d84b674d 100644 --- a/skrl/multi_agents/jax/ippo/ippo.py +++ b/skrl/multi_agents/jax/ippo/ippo.py @@ -66,12 +66,14 @@ # fmt: on -def compute_gae(rewards: np.ndarray, - dones: np.ndarray, - values: np.ndarray, - next_values: np.ndarray, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> np.ndarray: +def compute_gae( + rewards: np.ndarray, + dones: np.ndarray, + values: np.ndarray, + next_values: np.ndarray, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, +) -> np.ndarray: """Compute the Generalized Advantage Estimator (GAE) :param rewards: Rewards obtained by the agent @@ -98,7 +100,9 @@ def compute_gae(rewards: np.ndarray, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else next_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages[i] = advantage # returns computation returns = advantages + values @@ -107,14 +111,17 @@ def compute_gae(rewards: np.ndarray, return returns, advantages + # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @jax.jit -def _compute_gae(rewards: jax.Array, - dones: jax.Array, - values: jax.Array, - next_values: jax.Array, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> jax.Array: +def _compute_gae( + rewards: jax.Array, + dones: jax.Array, + values: jax.Array, + next_values: jax.Array, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, +) -> jax.Array: advantage = 0 advantages = jnp.zeros_like(rewards) not_dones = jnp.logical_not(dones) @@ -123,7 +130,9 @@ def _compute_gae(rewards: jax.Array, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < 
memory_size - 1 else next_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages = advantages.at[i].set(advantage) # returns computation returns = advantages + values @@ -132,19 +141,24 @@ def _compute_gae(rewards: jax.Array, return returns, advantages + @functools.partial(jax.jit, static_argnames=("policy_act", "get_entropy", "entropy_loss_scale")) -def _update_policy(policy_act, - policy_state_dict, - sampled_states, - sampled_actions, - sampled_log_prob, - sampled_advantages, - ratio_clip, - get_entropy, - entropy_loss_scale): +def _update_policy( + policy_act, + policy_state_dict, + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_advantages, + ratio_clip, + get_entropy, + entropy_loss_scale, +): # compute policy loss def _policy_loss(params): - _, next_log_prob, outputs = policy_act({"states": sampled_states, "taken_actions": sampled_actions}, "policy", params) + _, next_log_prob, outputs = policy_act( + {"states": sampled_states, "taken_actions": sampled_actions}, "policy", params + ) # compute approximate KL divergence ratio = next_log_prob - sampled_log_prob @@ -162,19 +176,24 @@ def _policy_loss(params): return -jnp.minimum(surrogate, surrogate_clipped).mean(), (entropy_loss, kl_divergence, outputs["stddev"]) - (policy_loss, (entropy_loss, kl_divergence, stddev)), grad = jax.value_and_grad(_policy_loss, has_aux=True)(policy_state_dict.params) + (policy_loss, (entropy_loss, kl_divergence, stddev)), grad = jax.value_and_grad(_policy_loss, has_aux=True)( + policy_state_dict.params + ) return grad, policy_loss, entropy_loss, kl_divergence, stddev + @functools.partial(jax.jit, static_argnames=("value_act", "clip_predicted_values")) -def _update_value(value_act, - value_state_dict, - sampled_states, - sampled_values, - sampled_returns, - value_loss_scale, - clip_predicted_values, - value_clip): +def _update_value( + value_act, + value_state_dict, + sampled_states, + sampled_values, + sampled_returns, + value_loss_scale, + clip_predicted_values, + value_clip, +): # compute value loss def _value_loss(params): predicted_values, _, _ = value_act({"states": sampled_states}, "value", params) @@ -188,14 +207,16 @@ def _value_loss(params): class IPPO(MultiAgent): - def __init__(self, - possible_agents: Sequence[str], - models: Mapping[str, Model], - memories: Optional[Mapping[str, Memory]] = None, - observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, - action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, - device: Optional[Union[str, jax.Device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + possible_agents: Sequence[str], + models: Mapping[str, Model], + memories: Optional[Mapping[str, Memory]] = None, + observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, + action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, + device: Optional[Union[str, jax.Device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Independent Proximal Policy Optimization (IPPO) https://arxiv.org/abs/2011.09533 @@ -220,13 +241,15 @@ def __init__(self, # _cfg = copy.deepcopy(IPPO_DEFAULT_CONFIG) # TODO: TypeError: cannot pickle 'jax.Device' object _cfg = IPPO_DEFAULT_CONFIG _cfg.update(cfg if cfg is not None else {}) - 
super().__init__(possible_agents=possible_agents, - models=models, - memories=memories, - observation_spaces=observation_spaces, - action_spaces=action_spaces, - device=device, - cfg=_cfg) + super().__init__( + possible_agents=possible_agents, + models=models, + memories=memories, + observation_spaces=observation_spaces, + action_spaces=action_spaces, + device=device, + cfg=_cfg, + ) # models self.policies = {uid: self.models[uid].get("policy", None) for uid in self.possible_agents} @@ -295,12 +318,20 @@ def __init__(self, if self._learning_rate_scheduler[uid] is not None: if self._learning_rate_scheduler[uid] == KLAdaptiveLR: scale = False - self.schedulers[uid] = self._learning_rate_scheduler[uid](self._learning_rate[uid], **self._learning_rate_scheduler_kwargs[uid]) + self.schedulers[uid] = self._learning_rate_scheduler[uid]( + self._learning_rate[uid], **self._learning_rate_scheduler_kwargs[uid] + ) else: - self._learning_rate[uid] = self._learning_rate_scheduler[uid](self._learning_rate[uid], **self._learning_rate_scheduler_kwargs[uid]) + self._learning_rate[uid] = self._learning_rate_scheduler[uid]( + self._learning_rate[uid], **self._learning_rate_scheduler_kwargs[uid] + ) # optimizer - self.policy_optimizer[uid] = Adam(model=policy, lr=self._learning_rate[uid], grad_norm_clip=self._grad_norm_clip[uid], scale=scale) - self.value_optimizer[uid] = Adam(model=value, lr=self._learning_rate[uid], grad_norm_clip=self._grad_norm_clip[uid], scale=scale) + self.policy_optimizer[uid] = Adam( + model=policy, lr=self._learning_rate[uid], grad_norm_clip=self._grad_norm_clip[uid], scale=scale + ) + self.value_optimizer[uid] = Adam( + model=value, lr=self._learning_rate[uid], grad_norm_clip=self._grad_norm_clip[uid], scale=scale + ) self.checkpoint_modules[uid]["policy_optimizer"] = self.policy_optimizer[uid] self.checkpoint_modules[uid]["value_optimizer"] = self.value_optimizer[uid] @@ -319,8 +350,7 @@ def __init__(self, self._value_preprocessor[uid] = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -349,7 +379,9 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: if self.values[uid] is not None: self.values[uid].apply = jax.jit(self.values[uid].apply, static_argnums=2) - def act(self, states: Mapping[str, Union[np.ndarray, jax.Array]], timestep: int, timesteps: int) -> Union[np.ndarray, jax.Array]: + def act( + self, states: Mapping[str, Union[np.ndarray, jax.Array]], timestep: int, timesteps: int + ) -> Union[np.ndarray, jax.Array]: """Process the environment's states to make a decision (actions) using the main policies :param states: Environment's states @@ -368,7 +400,10 @@ def act(self, states: Mapping[str, Union[np.ndarray, jax.Array]], timestep: int, # return self.policy.random_act({"states": states}, role="policy") # sample stochastic actions - data = [self.policies[uid].act({"states": self._state_preprocessor[uid](states[uid])}, role="policy") for uid in self.possible_agents] + data = [ + self.policies[uid].act({"states": self._state_preprocessor[uid](states[uid])}, role="policy") + for uid in self.possible_agents + ] actions = {uid: d[0] for uid, d in zip(self.possible_agents, data)} log_prob = {uid: d[1] for uid, d in zip(self.possible_agents, data)} @@ -382,16 +417,18 @@ def act(self, states: Mapping[str, Union[np.ndarray, jax.Array]], timestep: int, return actions, log_prob, outputs - def 
record_transition(self, - states: Mapping[str, Union[np.ndarray, jax.Array]], - actions: Mapping[str, Union[np.ndarray, jax.Array]], - rewards: Mapping[str, Union[np.ndarray, jax.Array]], - next_states: Mapping[str, Union[np.ndarray, jax.Array]], - terminated: Mapping[str, Union[np.ndarray, jax.Array]], - truncated: Mapping[str, Union[np.ndarray, jax.Array]], - infos: Mapping[str, Any], - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Mapping[str, Union[np.ndarray, jax.Array]], + actions: Mapping[str, Union[np.ndarray, jax.Array]], + rewards: Mapping[str, Union[np.ndarray, jax.Array]], + next_states: Mapping[str, Union[np.ndarray, jax.Array]], + terminated: Mapping[str, Union[np.ndarray, jax.Array]], + truncated: Mapping[str, Union[np.ndarray, jax.Array]], + infos: Mapping[str, Any], + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -413,7 +450,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memories: self._current_next_states = next_states @@ -424,7 +463,9 @@ def record_transition(self, rewards[uid] = self._rewards_shaper(rewards[uid], timestep, timesteps) # compute values - values, _, _ = self.values[uid].act({"states": self._state_preprocessor[uid](states[uid])}, role="value") + values, _, _ = self.values[uid].act( + {"states": self._state_preprocessor[uid](states[uid])}, role="value" + ) if not self._jax: # numpy backend values = jax.device_get(values) values = self._value_preprocessor[uid](values, inverse=True) @@ -434,8 +475,16 @@ def record_transition(self, rewards[uid] += self._discount_factor[uid] * values * truncated[uid] # storage transition in memory - self.memories[uid].add_samples(states=states[uid], actions=actions[uid], rewards=rewards[uid], next_states=next_states[uid], - terminated=terminated[uid], truncated=truncated[uid], log_prob=self._current_log_prob[uid], values=values) + self.memories[uid].add_samples( + states=states[uid], + actions=actions[uid], + rewards=rewards[uid], + next_states=next_states[uid], + terminated=terminated[uid], + truncated=truncated[uid], + log_prob=self._current_log_prob[uid], + values=values, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -479,7 +528,9 @@ def _update(self, timestep: int, timesteps: int) -> None: # compute returns and advantages value.training = False - last_values, _, _ = value.act({"states": self._state_preprocessor[uid](self._current_next_states[uid])}, role="value") # TODO: .float() + last_values, _, _ = value.act( + {"states": self._state_preprocessor[uid](self._current_next_states[uid])}, role="value" + ) # TODO: .float() value.training = True if not self._jax: # numpy backend last_values = jax.device_get(last_values) @@ -487,19 +538,23 @@ def _update(self, timestep: int, timesteps: int) -> None: values = memory.get_tensor_by_name("values") if self._jax: - returns, advantages = _compute_gae(rewards=memory.get_tensor_by_name("rewards"), - dones=memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor[uid], - 
lambda_coefficient=self._lambda[uid]) + returns, advantages = _compute_gae( + rewards=memory.get_tensor_by_name("rewards"), + dones=memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor[uid], + lambda_coefficient=self._lambda[uid], + ) else: - returns, advantages = compute_gae(rewards=memory.get_tensor_by_name("rewards"), - dones=memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor[uid], - lambda_coefficient=self._lambda[uid]) + returns, advantages = compute_gae( + rewards=memory.get_tensor_by_name("rewards"), + dones=memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor[uid], + lambda_coefficient=self._lambda[uid], + ) memory.set_tensor_by_name("values", self._value_preprocessor[uid](values, train=True)) memory.set_tensor_by_name("returns", self._value_preprocessor[uid](returns, train=True)) @@ -517,20 +572,29 @@ def _update(self, timestep: int, timesteps: int) -> None: kl_divergences = [] # mini-batches loop - for sampled_states, sampled_actions, sampled_log_prob, sampled_values, sampled_returns, sampled_advantages in sampled_batches: + for ( + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_values, + sampled_returns, + sampled_advantages, + ) in sampled_batches: sampled_states = self._state_preprocessor[uid](sampled_states, train=not epoch) # compute policy loss - grad, policy_loss, entropy_loss, kl_divergence, stddev = _update_policy(policy.act, - policy.state_dict, - sampled_states, - sampled_actions, - sampled_log_prob, - sampled_advantages, - self._ratio_clip[uid], - policy.get_entropy, - self._entropy_loss_scale[uid]) + grad, policy_loss, entropy_loss, kl_divergence, stddev = _update_policy( + policy.act, + policy.state_dict, + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_advantages, + self._ratio_clip[uid], + policy.get_entropy, + self._entropy_loss_scale[uid], + ) kl_divergences.append(kl_divergence.item()) @@ -541,22 +605,28 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimization step (policy) if config.jax.is_distributed: grad = policy.reduce_parameters(grad) - self.policy_optimizer[uid] = self.policy_optimizer[uid].step(grad, policy, self.schedulers[uid]._lr if self.schedulers[uid] else None) + self.policy_optimizer[uid] = self.policy_optimizer[uid].step( + grad, policy, self.schedulers[uid]._lr if self.schedulers[uid] else None + ) # compute value loss - grad, value_loss = _update_value(value.act, - value.state_dict, - sampled_states, - sampled_values, - sampled_returns, - self._value_loss_scale[uid], - self._clip_predicted_values[uid], - self._value_clip[uid]) + grad, value_loss = _update_value( + value.act, + value.state_dict, + sampled_states, + sampled_values, + sampled_returns, + self._value_loss_scale[uid], + self._clip_predicted_values[uid], + self._value_clip[uid], + ) # optimization step (value) if config.jax.is_distributed: grad = value.reduce_parameters(grad) - self.value_optimizer[uid] = self.value_optimizer[uid].step(grad, value, self.schedulers[uid]._lr if self.schedulers[uid] else None) + self.value_optimizer[uid] = self.value_optimizer[uid].step( + grad, value, self.schedulers[uid]._lr if self.schedulers[uid] else None + ) # update cumulative losses cumulative_policy_loss += policy_loss.item() @@ -570,15 +640,24 @@ def _update(self, timestep: int, timesteps: int) -> None: kl = 
np.mean(kl_divergences) # reduce (collect from all workers/processes) KL in distributed runs if config.jax.is_distributed: - kl = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(kl.reshape(1)).item() + kl = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(kl.reshape(1)).item() kl /= config.jax.world_size self.schedulers[uid].step(kl) # record data - self.track_data(f"Loss / Policy loss ({uid})", cumulative_policy_loss / (self._learning_epochs[uid] * self._mini_batches[uid])) - self.track_data(f"Loss / Value loss ({uid})", cumulative_value_loss / (self._learning_epochs[uid] * self._mini_batches[uid])) + self.track_data( + f"Loss / Policy loss ({uid})", + cumulative_policy_loss / (self._learning_epochs[uid] * self._mini_batches[uid]), + ) + self.track_data( + f"Loss / Value loss ({uid})", + cumulative_value_loss / (self._learning_epochs[uid] * self._mini_batches[uid]), + ) if self._entropy_loss_scale: - self.track_data(f"Loss / Entropy loss ({uid})", cumulative_entropy_loss / (self._learning_epochs[uid] * self._mini_batches[uid])) + self.track_data( + f"Loss / Entropy loss ({uid})", + cumulative_entropy_loss / (self._learning_epochs[uid] * self._mini_batches[uid]), + ) self.track_data(f"Policy / Standard deviation ({uid})", stddev.mean().item()) diff --git a/skrl/multi_agents/jax/mappo/mappo.py b/skrl/multi_agents/jax/mappo/mappo.py index ee384e52..2d1db277 100644 --- a/skrl/multi_agents/jax/mappo/mappo.py +++ b/skrl/multi_agents/jax/mappo/mappo.py @@ -68,12 +68,14 @@ # fmt: on -def compute_gae(rewards: np.ndarray, - dones: np.ndarray, - values: np.ndarray, - next_values: np.ndarray, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> np.ndarray: +def compute_gae( + rewards: np.ndarray, + dones: np.ndarray, + values: np.ndarray, + next_values: np.ndarray, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, +) -> np.ndarray: """Compute the Generalized Advantage Estimator (GAE) :param rewards: Rewards obtained by the agent @@ -100,7 +102,9 @@ def compute_gae(rewards: np.ndarray, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else next_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages[i] = advantage # returns computation returns = advantages + values @@ -109,14 +113,17 @@ def compute_gae(rewards: np.ndarray, return returns, advantages + # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @jax.jit -def _compute_gae(rewards: jax.Array, - dones: jax.Array, - values: jax.Array, - next_values: jax.Array, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> jax.Array: +def _compute_gae( + rewards: jax.Array, + dones: jax.Array, + values: jax.Array, + next_values: jax.Array, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, +) -> jax.Array: advantage = 0 advantages = jnp.zeros_like(rewards) not_dones = jnp.logical_not(dones) @@ -125,7 +132,9 @@ def _compute_gae(rewards: jax.Array, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else next_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] - values[i] + discount_factor * not_dones[i] * 
(next_values + lambda_coefficient * advantage) + ) advantages = advantages.at[i].set(advantage) # returns computation returns = advantages + values @@ -134,19 +143,24 @@ def _compute_gae(rewards: jax.Array, return returns, advantages + @functools.partial(jax.jit, static_argnames=("policy_act", "get_entropy", "entropy_loss_scale")) -def _update_policy(policy_act, - policy_state_dict, - sampled_states, - sampled_actions, - sampled_log_prob, - sampled_advantages, - ratio_clip, - get_entropy, - entropy_loss_scale): +def _update_policy( + policy_act, + policy_state_dict, + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_advantages, + ratio_clip, + get_entropy, + entropy_loss_scale, +): # compute policy loss def _policy_loss(params): - _, next_log_prob, outputs = policy_act({"states": sampled_states, "taken_actions": sampled_actions}, "policy", params) + _, next_log_prob, outputs = policy_act( + {"states": sampled_states, "taken_actions": sampled_actions}, "policy", params + ) # compute approximate KL divergence ratio = next_log_prob - sampled_log_prob @@ -164,19 +178,24 @@ def _policy_loss(params): return -jnp.minimum(surrogate, surrogate_clipped).mean(), (entropy_loss, kl_divergence, outputs["stddev"]) - (policy_loss, (entropy_loss, kl_divergence, stddev)), grad = jax.value_and_grad(_policy_loss, has_aux=True)(policy_state_dict.params) + (policy_loss, (entropy_loss, kl_divergence, stddev)), grad = jax.value_and_grad(_policy_loss, has_aux=True)( + policy_state_dict.params + ) return grad, policy_loss, entropy_loss, kl_divergence, stddev + @functools.partial(jax.jit, static_argnames=("value_act", "clip_predicted_values")) -def _update_value(value_act, - value_state_dict, - sampled_states, - sampled_values, - sampled_returns, - value_loss_scale, - clip_predicted_values, - value_clip): +def _update_value( + value_act, + value_state_dict, + sampled_states, + sampled_values, + sampled_returns, + value_loss_scale, + clip_predicted_values, + value_clip, +): # compute value loss def _value_loss(params): predicted_values, _, _ = value_act({"states": sampled_states}, "value", params) @@ -190,15 +209,17 @@ def _value_loss(params): class MAPPO(MultiAgent): - def __init__(self, - possible_agents: Sequence[str], - models: Mapping[str, Model], - memories: Optional[Mapping[str, Memory]] = None, - observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, - action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, - device: Optional[Union[str, jax.Device]] = None, - cfg: Optional[dict] = None, - shared_observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None) -> None: + def __init__( + self, + possible_agents: Sequence[str], + models: Mapping[str, Model], + memories: Optional[Mapping[str, Memory]] = None, + observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, + action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, + device: Optional[Union[str, jax.Device]] = None, + cfg: Optional[dict] = None, + shared_observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, + ) -> None: """Multi-Agent Proximal Policy Optimization (MAPPO) https://arxiv.org/abs/2103.01955 @@ -225,13 +246,15 @@ def __init__(self, # _cfg = copy.deepcopy(IPPO_DEFAULT_CONFIG) # TODO: TypeError: cannot pickle 'jax.Device' object _cfg = MAPPO_DEFAULT_CONFIG _cfg.update(cfg if cfg is not None else {}) - 
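For reference while reviewing the two GAE helpers reformatted above (the NumPy compute_gae and the jitted _compute_gae), the recursion they implement can be reproduced in a few lines. This is a minimal sketch under the same conventions (bootstrap from next_values at the last step, advantages standardized at the end); the name gae_reference and the toy inputs are illustrative and not part of skrl.

import numpy as np

def gae_reference(rewards, dones, values, next_values, discount=0.99, lam=0.95):
    # backward recursion: A_t = delta_t + discount * lam * A_{t+1}, masked when the episode terminated
    advantage = 0.0
    advantages = np.zeros_like(rewards)
    not_dones = np.logical_not(dones)
    memory_size = rewards.shape[0]
    for i in reversed(range(memory_size)):
        nv = values[i + 1] if i < memory_size - 1 else next_values
        advantage = rewards[i] - values[i] + discount * not_dones[i] * (nv + lam * advantage)
        advantages[i] = advantage
    returns = advantages + values
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return returns, advantages

# toy rollout: 4 steps of reward 1.0, terminating at the last step
returns, advantages = gae_reference(
    rewards=np.ones(4), dones=np.array([0.0, 0.0, 0.0, 1.0]), values=np.full(4, 0.5), next_values=0.0
)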
super().__init__(possible_agents=possible_agents, - models=models, - memories=memories, - observation_spaces=observation_spaces, - action_spaces=action_spaces, - device=device, - cfg=_cfg) + super().__init__( + possible_agents=possible_agents, + models=models, + memories=memories, + observation_spaces=observation_spaces, + action_spaces=action_spaces, + device=device, + cfg=_cfg, + ) self.shared_observation_spaces = shared_observation_spaces @@ -304,12 +327,20 @@ def __init__(self, if self._learning_rate_scheduler[uid] is not None: if self._learning_rate_scheduler[uid] == KLAdaptiveLR: scale = False - self.schedulers[uid] = self._learning_rate_scheduler[uid](self._learning_rate[uid], **self._learning_rate_scheduler_kwargs[uid]) + self.schedulers[uid] = self._learning_rate_scheduler[uid]( + self._learning_rate[uid], **self._learning_rate_scheduler_kwargs[uid] + ) else: - self._learning_rate[uid] = self._learning_rate_scheduler[uid](self._learning_rate[uid], **self._learning_rate_scheduler_kwargs[uid]) + self._learning_rate[uid] = self._learning_rate_scheduler[uid]( + self._learning_rate[uid], **self._learning_rate_scheduler_kwargs[uid] + ) # optimizer - self.policy_optimizer[uid] = Adam(model=policy, lr=self._learning_rate[uid], grad_norm_clip=self._grad_norm_clip[uid], scale=scale) - self.value_optimizer[uid] = Adam(model=value, lr=self._learning_rate[uid], grad_norm_clip=self._grad_norm_clip[uid], scale=scale) + self.policy_optimizer[uid] = Adam( + model=policy, lr=self._learning_rate[uid], grad_norm_clip=self._grad_norm_clip[uid], scale=scale + ) + self.value_optimizer[uid] = Adam( + model=value, lr=self._learning_rate[uid], grad_norm_clip=self._grad_norm_clip[uid], scale=scale + ) self.checkpoint_modules[uid]["policy_optimizer"] = self.policy_optimizer[uid] self.checkpoint_modules[uid]["value_optimizer"] = self.value_optimizer[uid] @@ -322,7 +353,9 @@ def __init__(self, self._state_preprocessor[uid] = self._empty_preprocessor if self._shared_state_preprocessor[uid] is not None: - self._shared_state_preprocessor[uid] = self._shared_state_preprocessor[uid](**self._shared_state_preprocessor_kwargs[uid]) + self._shared_state_preprocessor[uid] = self._shared_state_preprocessor[uid]( + **self._shared_state_preprocessor_kwargs[uid] + ) self.checkpoint_modules[uid]["shared_state_preprocessor"] = self._shared_state_preprocessor[uid] else: self._shared_state_preprocessor[uid] = self._empty_preprocessor @@ -334,8 +367,7 @@ def __init__(self, self._value_preprocessor[uid] = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -343,7 +375,9 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: if self.memories: for uid in self.possible_agents: self.memories[uid].create_tensor(name="states", size=self.observation_spaces[uid], dtype=jnp.float32) - self.memories[uid].create_tensor(name="shared_states", size=self.shared_observation_spaces[uid], dtype=jnp.float32) + self.memories[uid].create_tensor( + name="shared_states", size=self.shared_observation_spaces[uid], dtype=jnp.float32 + ) self.memories[uid].create_tensor(name="actions", size=self.action_spaces[uid], dtype=jnp.float32) self.memories[uid].create_tensor(name="rewards", size=1, dtype=jnp.float32) self.memories[uid].create_tensor(name="terminated", size=1, dtype=jnp.int8) @@ -353,7 +387,15 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> 
None: self.memories[uid].create_tensor(name="advantages", size=1, dtype=jnp.float32) # tensors sampled during training - self._tensors_names = ["states", "shared_states", "actions", "log_prob", "values", "returns", "advantages"] + self._tensors_names = [ + "states", + "shared_states", + "actions", + "log_prob", + "values", + "returns", + "advantages", + ] # create temporary variables needed for storage and computation self._current_log_prob = [] @@ -365,7 +407,9 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: if self.values[uid] is not None: self.values[uid].apply = jax.jit(self.values[uid].apply, static_argnums=2) - def act(self, states: Mapping[str, Union[np.ndarray, jax.Array]], timestep: int, timesteps: int) -> Union[np.ndarray, jax.Array]: + def act( + self, states: Mapping[str, Union[np.ndarray, jax.Array]], timestep: int, timesteps: int + ) -> Union[np.ndarray, jax.Array]: """Process the environment's states to make a decision (actions) using the main policies :param states: Environment's states @@ -384,7 +428,10 @@ def act(self, states: Mapping[str, Union[np.ndarray, jax.Array]], timestep: int, # return self.policy.random_act({"states": states}, role="policy") # sample stochastic actions - data = [self.policies[uid].act({"states": self._state_preprocessor[uid](states[uid])}, role="policy") for uid in self.possible_agents] + data = [ + self.policies[uid].act({"states": self._state_preprocessor[uid](states[uid])}, role="policy") + for uid in self.possible_agents + ] actions = {uid: d[0] for uid, d in zip(self.possible_agents, data)} log_prob = {uid: d[1] for uid, d in zip(self.possible_agents, data)} @@ -398,16 +445,18 @@ def act(self, states: Mapping[str, Union[np.ndarray, jax.Array]], timestep: int, return actions, log_prob, outputs - def record_transition(self, - states: Mapping[str, Union[np.ndarray, jax.Array]], - actions: Mapping[str, Union[np.ndarray, jax.Array]], - rewards: Mapping[str, Union[np.ndarray, jax.Array]], - next_states: Mapping[str, Union[np.ndarray, jax.Array]], - terminated: Mapping[str, Union[np.ndarray, jax.Array]], - truncated: Mapping[str, Union[np.ndarray, jax.Array]], - infos: Mapping[str, Any], - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Mapping[str, Union[np.ndarray, jax.Array]], + actions: Mapping[str, Union[np.ndarray, jax.Array]], + rewards: Mapping[str, Union[np.ndarray, jax.Array]], + next_states: Mapping[str, Union[np.ndarray, jax.Array]], + terminated: Mapping[str, Union[np.ndarray, jax.Array]], + truncated: Mapping[str, Union[np.ndarray, jax.Array]], + infos: Mapping[str, Any], + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -429,7 +478,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memories: shared_states = infos["shared_states"] @@ -441,7 +492,9 @@ def record_transition(self, rewards[uid] = self._rewards_shaper(rewards[uid], timestep, timesteps) # compute values - values, _, _ = self.values[uid].act({"states": self._shared_state_preprocessor[uid](shared_states)}, role="value") + values, _, _ = self.values[uid].act( + {"states": 
self._shared_state_preprocessor[uid](shared_states)}, role="value" + ) if not self._jax: # numpy backend values = jax.device_get(values) values = self._value_preprocessor[uid](values, inverse=True) @@ -451,9 +504,17 @@ def record_transition(self, rewards[uid] += self._discount_factor[uid] * values * truncated[uid] # storage transition in memory - self.memories[uid].add_samples(states=states[uid], actions=actions[uid], rewards=rewards[uid], next_states=next_states[uid], - terminated=terminated[uid], truncated=truncated[uid], log_prob=self._current_log_prob[uid], values=values, - shared_states=shared_states) + self.memories[uid].add_samples( + states=states[uid], + actions=actions[uid], + rewards=rewards[uid], + next_states=next_states[uid], + terminated=terminated[uid], + truncated=truncated[uid], + log_prob=self._current_log_prob[uid], + values=values, + shared_states=shared_states, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -497,7 +558,9 @@ def _update(self, timestep: int, timesteps: int) -> None: # compute returns and advantages value.training = False - last_values, _, _ = value.act({"states": self._shared_state_preprocessor[uid](self._current_shared_next_states)}, role="value") # TODO: .float() + last_values, _, _ = value.act( + {"states": self._shared_state_preprocessor[uid](self._current_shared_next_states)}, role="value" + ) # TODO: .float() value.training = True if not self._jax: # numpy backend last_values = jax.device_get(last_values) @@ -505,19 +568,23 @@ def _update(self, timestep: int, timesteps: int) -> None: values = memory.get_tensor_by_name("values") if self._jax: - returns, advantages = _compute_gae(rewards=memory.get_tensor_by_name("rewards"), - dones=memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor[uid], - lambda_coefficient=self._lambda[uid]) + returns, advantages = _compute_gae( + rewards=memory.get_tensor_by_name("rewards"), + dones=memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor[uid], + lambda_coefficient=self._lambda[uid], + ) else: - returns, advantages = compute_gae(rewards=memory.get_tensor_by_name("rewards"), - dones=memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor[uid], - lambda_coefficient=self._lambda[uid]) + returns, advantages = compute_gae( + rewards=memory.get_tensor_by_name("rewards"), + dones=memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor[uid], + lambda_coefficient=self._lambda[uid], + ) memory.set_tensor_by_name("values", self._value_preprocessor[uid](values, train=True)) memory.set_tensor_by_name("returns", self._value_preprocessor[uid](returns, train=True)) @@ -535,22 +602,31 @@ def _update(self, timestep: int, timesteps: int) -> None: kl_divergences = [] # mini-batches loop - for sampled_states, sampled_shared_states, sampled_actions, sampled_log_prob, sampled_values, sampled_returns, sampled_advantages \ - in sampled_batches: + for ( + sampled_states, + sampled_shared_states, + sampled_actions, + sampled_log_prob, + sampled_values, + sampled_returns, + sampled_advantages, + ) in sampled_batches: sampled_states = self._state_preprocessor[uid](sampled_states, train=not epoch) sampled_shared_states = self._shared_state_preprocessor[uid](sampled_shared_states, 
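The rewards[uid] += self._discount_factor[uid] * values * truncated[uid] statement reformatted in this record_transition is the usual time-limit bootstrap: when an episode ends because of a time limit rather than a real terminal state, the critic's value estimate is folded back into the stored reward so the return target is not artificially cut to zero. A minimal sketch of the effect, with made-up numbers rather than skrl's tensors:

import numpy as np

gamma = 0.99
reward = np.array([1.0, 1.0])            # rewards at the final step of two environments
truncated = np.array([1.0, 0.0])         # env 0 hit the time limit, env 1 terminated normally
value_estimate = np.array([10.0, 10.0])  # critic value estimate at the cut-off state

# only the truncated environment receives the bootstrapped tail value
reward = reward + gamma * value_estimate * truncated
# -> array([10.9,  1. ])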
train=not epoch) # compute policy loss - grad, policy_loss, entropy_loss, kl_divergence, stddev = _update_policy(policy.act, - policy.state_dict, - sampled_states, - sampled_actions, - sampled_log_prob, - sampled_advantages, - self._ratio_clip[uid], - policy.get_entropy, - self._entropy_loss_scale[uid]) + grad, policy_loss, entropy_loss, kl_divergence, stddev = _update_policy( + policy.act, + policy.state_dict, + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_advantages, + self._ratio_clip[uid], + policy.get_entropy, + self._entropy_loss_scale[uid], + ) kl_divergences.append(kl_divergence.item()) @@ -561,22 +637,28 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimization step (policy) if config.jax.is_distributed: grad = policy.reduce_parameters(grad) - self.policy_optimizer[uid] = self.policy_optimizer[uid].step(grad, policy, self.schedulers[uid]._lr if self.schedulers[uid] else None) + self.policy_optimizer[uid] = self.policy_optimizer[uid].step( + grad, policy, self.schedulers[uid]._lr if self.schedulers[uid] else None + ) # compute value loss - grad, value_loss = _update_value(value.act, - value.state_dict, - sampled_shared_states, - sampled_values, - sampled_returns, - self._value_loss_scale[uid], - self._clip_predicted_values[uid], - self._value_clip[uid]) + grad, value_loss = _update_value( + value.act, + value.state_dict, + sampled_shared_states, + sampled_values, + sampled_returns, + self._value_loss_scale[uid], + self._clip_predicted_values[uid], + self._value_clip[uid], + ) # optimization step (value) if config.jax.is_distributed: grad = value.reduce_parameters(grad) - self.value_optimizer[uid] = self.value_optimizer[uid].step(grad, value, self.schedulers[uid]._lr if self.schedulers[uid] else None) + self.value_optimizer[uid] = self.value_optimizer[uid].step( + grad, value, self.schedulers[uid]._lr if self.schedulers[uid] else None + ) # update cumulative losses cumulative_policy_loss += policy_loss.item() @@ -590,15 +672,24 @@ def _update(self, timestep: int, timesteps: int) -> None: kl = np.mean(kl_divergences) # reduce (collect from all workers/processes) KL in distributed runs if config.jax.is_distributed: - kl = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(kl.reshape(1)).item() + kl = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(kl.reshape(1)).item() kl /= config.jax.world_size self.schedulers[uid].step(kl) # record data - self.track_data(f"Loss / Policy loss ({uid})", cumulative_policy_loss / (self._learning_epochs[uid] * self._mini_batches[uid])) - self.track_data(f"Loss / Value loss ({uid})", cumulative_value_loss / (self._learning_epochs[uid] * self._mini_batches[uid])) + self.track_data( + f"Loss / Policy loss ({uid})", + cumulative_policy_loss / (self._learning_epochs[uid] * self._mini_batches[uid]), + ) + self.track_data( + f"Loss / Value loss ({uid})", + cumulative_value_loss / (self._learning_epochs[uid] * self._mini_batches[uid]), + ) if self._entropy_loss_scale: - self.track_data(f"Loss / Entropy loss ({uid})", cumulative_entropy_loss / (self._learning_epochs[uid] * self._mini_batches[uid])) + self.track_data( + f"Loss / Entropy loss ({uid})", + cumulative_entropy_loss / (self._learning_epochs[uid] * self._mini_batches[uid]), + ) self.track_data(f"Policy / Standard deviation ({uid})", stddev.mean().item()) diff --git a/skrl/multi_agents/torch/base.py b/skrl/multi_agents/torch/base.py index 3eecf23f..f41bc5d3 100644 --- a/skrl/multi_agents/torch/base.py +++ b/skrl/multi_agents/torch/base.py @@ -17,14 
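The distributed branch reformatted just above averages the per-process KL estimate with jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i") followed by a division by the world size. The collective in isolation looks like the sketch below; it only assumes the local devices JAX reports, not skrl's distributed configuration, and mean_kl is an illustrative name.

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
# one scalar KL estimate per local device (the leading axis must equal the device count)
local_kl = jnp.arange(float(n_devices)) + 1.0

# psum adds the per-device values along the mapped axis "i"; every device then holds the same sum
summed = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(local_kl)
mean_kl = float(summed[0]) / n_devices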
+17,16 @@ class MultiAgent: - def __init__(self, - possible_agents: Sequence[str], - models: Mapping[str, Mapping[str, Model]], - memories: Optional[Mapping[str, Memory]] = None, - observation_spaces: Optional[Mapping[str, Union[int, Sequence[int], gymnasium.Space]]] = None, - action_spaces: Optional[Mapping[str, Union[int, Sequence[int], gymnasium.Space]]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + possible_agents: Sequence[str], + models: Mapping[str, Mapping[str, Model]], + memories: Optional[Mapping[str, Memory]] = None, + observation_spaces: Optional[Mapping[str, Union[int, Sequence[int], gymnasium.Space]]] = None, + action_spaces: Optional[Mapping[str, Union[int, Sequence[int], gymnasium.Space]]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Base class that represent a RL multi-agent :param possible_agents: Name of all possible agents the environment could generate @@ -53,7 +55,9 @@ def __init__(self, self.action_spaces = action_spaces self.cfg = cfg if cfg is not None else {} - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device) + self.device = ( + torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device) + ) # convert the models to their respective device for _models in self.models.values(): @@ -75,7 +79,7 @@ def __init__(self, self.checkpoint_modules = {uid: {} for uid in self.possible_agents} self.checkpoint_interval = self.cfg.get("experiment", {}).get("checkpoint_interval", 1000) self.checkpoint_store_separately = self.cfg.get("experiment", {}).get("store_separately", False) - self.checkpoint_best_modules = {"timestep": 0, "reward": -2 ** 31, "saved": True, "modules": {}} + self.checkpoint_best_modules = {"timestep": 0, "reward": -(2**31), "saved": True, "modules": {}} # experiment directory directory = self.cfg.get("experiment", {}).get("directory", "") @@ -166,10 +170,14 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: if self.cfg.get("experiment", {}).get("wandb", False): # save experiment configuration try: - models_cfg = {uid: {k: v.net._modules for (k, v) in self.models[uid].items()} for uid in self.possible_agents} + models_cfg = { + uid: {k: v.net._modules for (k, v) in self.models[uid].items()} for uid in self.possible_agents + } except AttributeError: - models_cfg = {uid: {k: v._modules for (k, v) in self.models[uid].items()} for uid in self.possible_agents} - wandb_config={**self.cfg, **trainer_cfg, **models_cfg} + models_cfg = { + uid: {k: v._modules for (k, v) in self.models[uid].items()} for uid in self.possible_agents + } + wandb_config = {**self.cfg, **trainer_cfg, **models_cfg} # set default values wandb_kwargs = copy.deepcopy(self.cfg.get("experiment", {}).get("wandb_kwargs", {})) wandb_kwargs.setdefault("name", os.path.split(self.experiment_dir)[-1]) @@ -178,6 +186,7 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: wandb_kwargs["config"].update(wandb_config) # init Weights & Biases import wandb + wandb.init(**wandb_kwargs) # main entry to log data for consumption and visualization by TensorBoard @@ -239,12 +248,16 @@ def write_checkpoint(self, timestep: int, timesteps: int) -> None: if self.checkpoint_store_separately: for uid in self.possible_agents: for name, module in self.checkpoint_modules[uid].items(): - 
torch.save(self._get_internal_value(module), - os.path.join(self.experiment_dir, "checkpoints", f"{uid}_{name}_{tag}.pt")) + torch.save( + self._get_internal_value(module), + os.path.join(self.experiment_dir, "checkpoints", f"{uid}_{name}_{tag}.pt"), + ) # whole agent else: - modules = {uid: {name: self._get_internal_value(module) for name, module in self.checkpoint_modules[uid].items()} \ - for uid in self.possible_agents} + modules = { + uid: {name: self._get_internal_value(module) for name, module in self.checkpoint_modules[uid].items()} + for uid in self.possible_agents + } torch.save(modules, os.path.join(self.experiment_dir, "checkpoints", f"agent_{tag}.pt")) # best modules @@ -253,12 +266,19 @@ def write_checkpoint(self, timestep: int, timesteps: int) -> None: if self.checkpoint_store_separately: for uid in self.possible_agents: for name in self.checkpoint_modules[uid].keys(): - torch.save(self.checkpoint_best_modules["modules"][uid][name], - os.path.join(self.experiment_dir, "checkpoints", f"best_{uid}_{name}.pt")) + torch.save( + self.checkpoint_best_modules["modules"][uid][name], + os.path.join(self.experiment_dir, "checkpoints", f"best_{uid}_{name}.pt"), + ) # whole agent else: - modules = {uid: {name: self.checkpoint_best_modules["modules"][uid][name] \ - for name in self.checkpoint_modules[uid].keys()} for uid in self.possible_agents} + modules = { + uid: { + name: self.checkpoint_best_modules["modules"][uid][name] + for name in self.checkpoint_modules[uid].keys() + } + for uid in self.possible_agents + } torch.save(modules, os.path.join(self.experiment_dir, "checkpoints", "best_agent.pt")) self.checkpoint_best_modules["saved"] = True @@ -279,16 +299,18 @@ def act(self, states: Mapping[str, torch.Tensor], timestep: int, timesteps: int) """ raise NotImplementedError - def record_transition(self, - states: Mapping[str, torch.Tensor], - actions: Mapping[str, torch.Tensor], - rewards: Mapping[str, torch.Tensor], - next_states: Mapping[str, torch.Tensor], - terminated: Mapping[str, torch.Tensor], - truncated: Mapping[str, torch.Tensor], - infos: Mapping[str, Any], - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Mapping[str, torch.Tensor], + actions: Mapping[str, torch.Tensor], + rewards: Mapping[str, torch.Tensor], + next_states: Mapping[str, torch.Tensor], + terminated: Mapping[str, torch.Tensor], + truncated: Mapping[str, torch.Tensor], + infos: Mapping[str, Any], + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory (to be implemented by the inheriting classes) Inheriting classes must call this method to record episode information (rewards, timesteps, etc.). @@ -381,8 +403,10 @@ def save(self, path: str) -> None: :param path: Path to save the model to :type path: str """ - modules = {uid: {name: self._get_internal_value(module) for name, module in self.checkpoint_modules[uid].items()} \ - for uid in self.possible_agents} + modules = { + uid: {name: self._get_internal_value(module) for name, module in self.checkpoint_modules[uid].items()} + for uid in self.possible_agents + } torch.save(modules, path) def load(self, path: str) -> None: @@ -414,11 +438,13 @@ def load(self, path: str) -> None: else: logger.warning(f"Cannot load the {uid}:{name} module. 
The agent doesn't have such an instance") - def migrate(self, - path: str, - name_map: Mapping[str, Mapping[str, str]] = {}, - auto_mapping: bool = True, - verbose: bool = False) -> bool: + def migrate( + self, + path: str, + name_map: Mapping[str, Mapping[str, str]] = {}, + auto_mapping: bool = True, + verbose: bool = False, + ) -> bool: """Migrate the specified extrernal checkpoint to the current agent The final storage device is determined by the constructor of the agent. @@ -467,13 +493,17 @@ def post_interaction(self, timestep: int, timesteps: int) -> None: # update best models and write checkpoints if timestep > 1 and self.checkpoint_interval > 0 and not timestep % self.checkpoint_interval: # update best models - reward = np.mean(self.tracking_data.get("Reward / Total reward (mean)", -2 ** 31)) + reward = np.mean(self.tracking_data.get("Reward / Total reward (mean)", -(2**31))) if reward > self.checkpoint_best_modules["reward"]: self.checkpoint_best_modules["timestep"] = timestep self.checkpoint_best_modules["reward"] = reward self.checkpoint_best_modules["saved"] = False - self.checkpoint_best_modules["modules"] = {uid: {k: copy.deepcopy(self._get_internal_value(v)) \ - for k, v in self.checkpoint_modules[uid].items()} for uid in self.possible_agents} + self.checkpoint_best_modules["modules"] = { + uid: { + k: copy.deepcopy(self._get_internal_value(v)) for k, v in self.checkpoint_modules[uid].items() + } + for uid in self.possible_agents + } # write checkpoints self.write_checkpoint(timestep, timesteps) diff --git a/skrl/multi_agents/torch/ippo/ippo.py b/skrl/multi_agents/torch/ippo/ippo.py index b84c7559..c8296fa8 100644 --- a/skrl/multi_agents/torch/ippo/ippo.py +++ b/skrl/multi_agents/torch/ippo/ippo.py @@ -67,14 +67,16 @@ class IPPO(MultiAgent): - def __init__(self, - possible_agents: Sequence[str], - models: Mapping[str, Model], - memories: Optional[Mapping[str, Memory]] = None, - observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, - action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, - device: Optional[Union[str, torch.device]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + possible_agents: Sequence[str], + models: Mapping[str, Model], + memories: Optional[Mapping[str, Memory]] = None, + observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, + action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + ) -> None: """Independent Proximal Policy Optimization (IPPO) https://arxiv.org/abs/2011.09533 @@ -98,13 +100,15 @@ def __init__(self, """ _cfg = copy.deepcopy(IPPO_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(possible_agents=possible_agents, - models=models, - memories=memories, - observation_spaces=observation_spaces, - action_spaces=action_spaces, - device=device, - cfg=_cfg) + super().__init__( + possible_agents=possible_agents, + models=models, + memories=memories, + observation_spaces=observation_spaces, + action_spaces=action_spaces, + device=device, + cfg=_cfg, + ) # models self.policies = {uid: self.models[uid].get("policy", None) for uid in self.possible_agents} @@ -168,11 +172,14 @@ def __init__(self, if policy is value: optimizer = torch.optim.Adam(policy.parameters(), lr=self._learning_rate[uid]) else: - optimizer = torch.optim.Adam(itertools.chain(policy.parameters(), 
value.parameters()), - lr=self._learning_rate[uid]) + optimizer = torch.optim.Adam( + itertools.chain(policy.parameters(), value.parameters()), lr=self._learning_rate[uid] + ) self.optimizers[uid] = optimizer if self._learning_rate_scheduler[uid] is not None: - self.schedulers[uid] = self._learning_rate_scheduler[uid](optimizer, **self._learning_rate_scheduler_kwargs[uid]) + self.schedulers[uid] = self._learning_rate_scheduler[uid]( + optimizer, **self._learning_rate_scheduler_kwargs[uid] + ) self.checkpoint_modules[uid]["optimizer"] = self.optimizers[uid] @@ -190,8 +197,7 @@ def __init__(self, self._value_preprocessor[uid] = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -233,7 +239,10 @@ def act(self, states: Mapping[str, torch.Tensor], timestep: int, timesteps: int) # return self.policy.random_act({"states": states}, role="policy") # sample stochastic actions - data = [self.policies[uid].act({"states": self._state_preprocessor[uid](states[uid])}, role="policy") for uid in self.possible_agents] + data = [ + self.policies[uid].act({"states": self._state_preprocessor[uid](states[uid])}, role="policy") + for uid in self.possible_agents + ] actions = {uid: d[0] for uid, d in zip(self.possible_agents, data)} log_prob = {uid: d[1] for uid, d in zip(self.possible_agents, data)} @@ -243,16 +252,18 @@ def act(self, states: Mapping[str, torch.Tensor], timestep: int, timesteps: int) return actions, log_prob, outputs - def record_transition(self, - states: Mapping[str, torch.Tensor], - actions: Mapping[str, torch.Tensor], - rewards: Mapping[str, torch.Tensor], - next_states: Mapping[str, torch.Tensor], - terminated: Mapping[str, torch.Tensor], - truncated: Mapping[str, torch.Tensor], - infos: Mapping[str, Any], - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Mapping[str, torch.Tensor], + actions: Mapping[str, torch.Tensor], + rewards: Mapping[str, torch.Tensor], + next_states: Mapping[str, torch.Tensor], + terminated: Mapping[str, torch.Tensor], + truncated: Mapping[str, torch.Tensor], + infos: Mapping[str, Any], + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -274,7 +285,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memories: self._current_next_states = next_states @@ -285,7 +298,9 @@ def record_transition(self, rewards[uid] = self._rewards_shaper(rewards[uid], timestep, timesteps) # compute values - values, _, _ = self.values[uid].act({"states": self._state_preprocessor[uid](states[uid])}, role="value") + values, _, _ = self.values[uid].act( + {"states": self._state_preprocessor[uid](states[uid])}, role="value" + ) values = self._value_preprocessor[uid](values, inverse=True) # time-limit (truncation) boostrapping @@ -293,8 +308,16 @@ def record_transition(self, rewards[uid] += self._discount_factor[uid] * values * truncated[uid] # storage transition in memory - self.memories[uid].add_samples(states=states[uid], actions=actions[uid], rewards=rewards[uid], 
next_states=next_states[uid], - terminated=terminated[uid], truncated=truncated[uid], log_prob=self._current_log_prob[uid], values=values) + self.memories[uid].add_samples( + states=states[uid], + actions=actions[uid], + rewards=rewards[uid], + next_states=next_states[uid], + terminated=terminated[uid], + truncated=truncated[uid], + log_prob=self._current_log_prob[uid], + values=values, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -331,12 +354,15 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - def compute_gae(rewards: torch.Tensor, - dones: torch.Tensor, - values: torch.Tensor, - next_values: torch.Tensor, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> torch.Tensor: + + def compute_gae( + rewards: torch.Tensor, + dones: torch.Tensor, + values: torch.Tensor, + next_values: torch.Tensor, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, + ) -> torch.Tensor: """Compute the Generalized Advantage Estimator (GAE) :param rewards: Rewards obtained by the agent @@ -363,7 +389,11 @@ def compute_gae(rewards: torch.Tensor, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else last_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] + - values[i] + + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages[i] = advantage # returns computation returns = advantages + values @@ -380,17 +410,21 @@ def compute_gae(rewards: torch.Tensor, # compute returns and advantages with torch.no_grad(): value.train(False) - last_values, _, _ = value.act({"states": self._state_preprocessor[uid](self._current_next_states[uid].float())}, role="value") + last_values, _, _ = value.act( + {"states": self._state_preprocessor[uid](self._current_next_states[uid].float())}, role="value" + ) value.train(True) last_values = self._value_preprocessor[uid](last_values, inverse=True) values = memory.get_tensor_by_name("values") - returns, advantages = compute_gae(rewards=memory.get_tensor_by_name("rewards"), - dones=memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor[uid], - lambda_coefficient=self._lambda[uid]) + returns, advantages = compute_gae( + rewards=memory.get_tensor_by_name("rewards"), + dones=memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor[uid], + lambda_coefficient=self._lambda[uid], + ) memory.set_tensor_by_name("values", self._value_preprocessor[uid](values, train=True)) memory.set_tensor_by_name("returns", self._value_preprocessor[uid](returns, train=True)) @@ -408,11 +442,20 @@ def compute_gae(rewards: torch.Tensor, kl_divergences = [] # mini-batches loop - for sampled_states, sampled_actions, sampled_log_prob, sampled_values, sampled_returns, sampled_advantages in sampled_batches: + for ( + sampled_states, + sampled_actions, + sampled_log_prob, + sampled_values, + sampled_returns, + sampled_advantages, + ) in sampled_batches: sampled_states = self._state_preprocessor[uid](sampled_states, train=not epoch) - _, next_log_prob, _ = policy.act({"states": sampled_states, "taken_actions": sampled_actions}, role="policy") + _, next_log_prob, _ = policy.act( + 
{"states": sampled_states, "taken_actions": sampled_actions}, role="policy" + ) # compute approximate KL divergence with torch.no_grad(): @@ -433,7 +476,9 @@ def compute_gae(rewards: torch.Tensor, # compute policy loss ratio = torch.exp(next_log_prob - sampled_log_prob) surrogate = sampled_advantages * ratio - surrogate_clipped = sampled_advantages * torch.clip(ratio, 1.0 - self._ratio_clip[uid], 1.0 + self._ratio_clip[uid]) + surrogate_clipped = sampled_advantages * torch.clip( + ratio, 1.0 - self._ratio_clip[uid], 1.0 + self._ratio_clip[uid] + ) policy_loss = -torch.min(surrogate, surrogate_clipped).mean() @@ -441,9 +486,9 @@ def compute_gae(rewards: torch.Tensor, predicted_values, _, _ = value.act({"states": sampled_states}, role="value") if self._clip_predicted_values: - predicted_values = sampled_values + torch.clip(predicted_values - sampled_values, - min=-self._value_clip[uid], - max=self._value_clip[uid]) + predicted_values = sampled_values + torch.clip( + predicted_values - sampled_values, min=-self._value_clip[uid], max=self._value_clip[uid] + ) value_loss = self._value_loss_scale[uid] * F.mse_loss(sampled_returns, predicted_values) # optimization step @@ -457,7 +502,9 @@ def compute_gae(rewards: torch.Tensor, if policy is value: nn.utils.clip_grad_norm_(policy.parameters(), self._grad_norm_clip[uid]) else: - nn.utils.clip_grad_norm_(itertools.chain(policy.parameters(), value.parameters()), self._grad_norm_clip[uid]) + nn.utils.clip_grad_norm_( + itertools.chain(policy.parameters(), value.parameters()), self._grad_norm_clip[uid] + ) self.optimizers[uid].step() # update cumulative losses @@ -479,12 +526,23 @@ def compute_gae(rewards: torch.Tensor, self.schedulers[uid].step() # record data - self.track_data(f"Loss / Policy loss ({uid})", cumulative_policy_loss / (self._learning_epochs[uid] * self._mini_batches[uid])) - self.track_data(f"Loss / Value loss ({uid})", cumulative_value_loss / (self._learning_epochs[uid] * self._mini_batches[uid])) + self.track_data( + f"Loss / Policy loss ({uid})", + cumulative_policy_loss / (self._learning_epochs[uid] * self._mini_batches[uid]), + ) + self.track_data( + f"Loss / Value loss ({uid})", + cumulative_value_loss / (self._learning_epochs[uid] * self._mini_batches[uid]), + ) if self._entropy_loss_scale: - self.track_data(f"Loss / Entropy loss ({uid})", cumulative_entropy_loss / (self._learning_epochs[uid] * self._mini_batches[uid])) - - self.track_data(f"Policy / Standard deviation ({uid})", policy.distribution(role="policy").stddev.mean().item()) + self.track_data( + f"Loss / Entropy loss ({uid})", + cumulative_entropy_loss / (self._learning_epochs[uid] * self._mini_batches[uid]), + ) + + self.track_data( + f"Policy / Standard deviation ({uid})", policy.distribution(role="policy").stddev.mean().item() + ) if self._learning_rate_scheduler[uid]: self.track_data(f"Learning / Learning rate ({uid})", self.schedulers[uid].get_last_lr()[0]) diff --git a/skrl/multi_agents/torch/mappo/mappo.py b/skrl/multi_agents/torch/mappo/mappo.py index d466c7bc..ae549c7d 100644 --- a/skrl/multi_agents/torch/mappo/mappo.py +++ b/skrl/multi_agents/torch/mappo/mappo.py @@ -69,15 +69,17 @@ class MAPPO(MultiAgent): - def __init__(self, - possible_agents: Sequence[str], - models: Mapping[str, Model], - memories: Optional[Mapping[str, Memory]] = None, - observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, - action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, - device: Optional[Union[str, 
torch.device]] = None, - cfg: Optional[dict] = None, - shared_observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None) -> None: + def __init__( + self, + possible_agents: Sequence[str], + models: Mapping[str, Model], + memories: Optional[Mapping[str, Memory]] = None, + observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, + action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, + device: Optional[Union[str, torch.device]] = None, + cfg: Optional[dict] = None, + shared_observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gymnasium.Space]]] = None, + ) -> None: """Multi-Agent Proximal Policy Optimization (MAPPO) https://arxiv.org/abs/2103.01955 @@ -103,13 +105,15 @@ def __init__(self, """ _cfg = copy.deepcopy(MAPPO_DEFAULT_CONFIG) _cfg.update(cfg if cfg is not None else {}) - super().__init__(possible_agents=possible_agents, - models=models, - memories=memories, - observation_spaces=observation_spaces, - action_spaces=action_spaces, - device=device, - cfg=_cfg) + super().__init__( + possible_agents=possible_agents, + models=models, + memories=memories, + observation_spaces=observation_spaces, + action_spaces=action_spaces, + device=device, + cfg=_cfg, + ) self.shared_observation_spaces = shared_observation_spaces @@ -177,11 +181,14 @@ def __init__(self, if policy is value: optimizer = torch.optim.Adam(policy.parameters(), lr=self._learning_rate[uid]) else: - optimizer = torch.optim.Adam(itertools.chain(policy.parameters(), value.parameters()), - lr=self._learning_rate[uid]) + optimizer = torch.optim.Adam( + itertools.chain(policy.parameters(), value.parameters()), lr=self._learning_rate[uid] + ) self.optimizers[uid] = optimizer if self._learning_rate_scheduler[uid] is not None: - self.schedulers[uid] = self._learning_rate_scheduler[uid](optimizer, **self._learning_rate_scheduler_kwargs[uid]) + self.schedulers[uid] = self._learning_rate_scheduler[uid]( + optimizer, **self._learning_rate_scheduler_kwargs[uid] + ) self.checkpoint_modules[uid]["optimizer"] = self.optimizers[uid] @@ -193,7 +200,9 @@ def __init__(self, self._state_preprocessor[uid] = self._empty_preprocessor if self._shared_state_preprocessor[uid] is not None: - self._shared_state_preprocessor[uid] = self._shared_state_preprocessor[uid](**self._shared_state_preprocessor_kwargs[uid]) + self._shared_state_preprocessor[uid] = self._shared_state_preprocessor[uid]( + **self._shared_state_preprocessor_kwargs[uid] + ) self.checkpoint_modules[uid]["shared_state_preprocessor"] = self._shared_state_preprocessor[uid] else: self._shared_state_preprocessor[uid] = self._empty_preprocessor @@ -205,8 +214,7 @@ def __init__(self, self._value_preprocessor[uid] = self._empty_preprocessor def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: - """Initialize the agent - """ + """Initialize the agent""" super().init(trainer_cfg=trainer_cfg) self.set_mode("eval") @@ -214,7 +222,9 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: if self.memories: for uid in self.possible_agents: self.memories[uid].create_tensor(name="states", size=self.observation_spaces[uid], dtype=torch.float32) - self.memories[uid].create_tensor(name="shared_states", size=self.shared_observation_spaces[uid], dtype=torch.float32) + self.memories[uid].create_tensor( + name="shared_states", size=self.shared_observation_spaces[uid], dtype=torch.float32 + ) self.memories[uid].create_tensor(name="actions", 
size=self.action_spaces[uid], dtype=torch.float32) self.memories[uid].create_tensor(name="rewards", size=1, dtype=torch.float32) self.memories[uid].create_tensor(name="terminated", size=1, dtype=torch.bool) @@ -224,7 +234,15 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: self.memories[uid].create_tensor(name="advantages", size=1, dtype=torch.float32) # tensors sampled during training - self._tensors_names = ["states", "shared_states", "actions", "log_prob", "values", "returns", "advantages"] + self._tensors_names = [ + "states", + "shared_states", + "actions", + "log_prob", + "values", + "returns", + "advantages", + ] # create temporary variables needed for storage and computation self._current_log_prob = [] @@ -249,7 +267,10 @@ def act(self, states: Mapping[str, torch.Tensor], timestep: int, timesteps: int) # return self.policy.random_act({"states": states}, role="policy") # sample stochastic actions - data = [self.policies[uid].act({"states": self._state_preprocessor[uid](states[uid])}, role="policy") for uid in self.possible_agents] + data = [ + self.policies[uid].act({"states": self._state_preprocessor[uid](states[uid])}, role="policy") + for uid in self.possible_agents + ] actions = {uid: d[0] for uid, d in zip(self.possible_agents, data)} log_prob = {uid: d[1] for uid, d in zip(self.possible_agents, data)} @@ -259,16 +280,18 @@ def act(self, states: Mapping[str, torch.Tensor], timestep: int, timesteps: int) return actions, log_prob, outputs - def record_transition(self, - states: Mapping[str, torch.Tensor], - actions: Mapping[str, torch.Tensor], - rewards: Mapping[str, torch.Tensor], - next_states: Mapping[str, torch.Tensor], - terminated: Mapping[str, torch.Tensor], - truncated: Mapping[str, torch.Tensor], - infos: Mapping[str, Any], - timestep: int, - timesteps: int) -> None: + def record_transition( + self, + states: Mapping[str, torch.Tensor], + actions: Mapping[str, torch.Tensor], + rewards: Mapping[str, torch.Tensor], + next_states: Mapping[str, torch.Tensor], + terminated: Mapping[str, torch.Tensor], + truncated: Mapping[str, torch.Tensor], + infos: Mapping[str, Any], + timestep: int, + timesteps: int, + ) -> None: """Record an environment transition in memory :param states: Observations/states of the environment used to make the decision @@ -290,7 +313,9 @@ def record_transition(self, :param timesteps: Number of timesteps :type timesteps: int """ - super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps) + super().record_transition( + states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ) if self.memories: shared_states = infos["shared_states"] @@ -302,7 +327,9 @@ def record_transition(self, rewards[uid] = self._rewards_shaper(rewards[uid], timestep, timesteps) # compute values - values, _, _ = self.values[uid].act({"states": self._shared_state_preprocessor[uid](shared_states)}, role="value") + values, _, _ = self.values[uid].act( + {"states": self._shared_state_preprocessor[uid](shared_states)}, role="value" + ) values = self._value_preprocessor[uid](values, inverse=True) # time-limit (truncation) boostrapping @@ -310,9 +337,17 @@ def record_transition(self, rewards[uid] += self._discount_factor[uid] * values * truncated[uid] # storage transition in memory - self.memories[uid].add_samples(states=states[uid], actions=actions[uid], rewards=rewards[uid], next_states=next_states[uid], - terminated=terminated[uid], truncated=truncated[uid], 
log_prob=self._current_log_prob[uid], values=values, - shared_states=shared_states) + self.memories[uid].add_samples( + states=states[uid], + actions=actions[uid], + rewards=rewards[uid], + next_states=next_states[uid], + terminated=terminated[uid], + truncated=truncated[uid], + log_prob=self._current_log_prob[uid], + values=values, + shared_states=shared_states, + ) def pre_interaction(self, timestep: int, timesteps: int) -> None: """Callback called before the interaction with the environment @@ -349,12 +384,15 @@ def _update(self, timestep: int, timesteps: int) -> None: :param timesteps: Number of timesteps :type timesteps: int """ - def compute_gae(rewards: torch.Tensor, - dones: torch.Tensor, - values: torch.Tensor, - next_values: torch.Tensor, - discount_factor: float = 0.99, - lambda_coefficient: float = 0.95) -> torch.Tensor: + + def compute_gae( + rewards: torch.Tensor, + dones: torch.Tensor, + values: torch.Tensor, + next_values: torch.Tensor, + discount_factor: float = 0.99, + lambda_coefficient: float = 0.95, + ) -> torch.Tensor: """Compute the Generalized Advantage Estimator (GAE) :param rewards: Rewards obtained by the agent @@ -381,7 +419,11 @@ def compute_gae(rewards: torch.Tensor, # advantages computation for i in reversed(range(memory_size)): next_values = values[i + 1] if i < memory_size - 1 else last_values - advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + advantage = ( + rewards[i] + - values[i] + + discount_factor * not_dones[i] * (next_values + lambda_coefficient * advantage) + ) advantages[i] = advantage # returns computation returns = advantages + values @@ -398,17 +440,22 @@ def compute_gae(rewards: torch.Tensor, # compute returns and advantages with torch.no_grad(): value.train(False) - last_values, _, _ = value.act({"states": self._shared_state_preprocessor[uid](self._current_shared_next_states.float())}, role="value") + last_values, _, _ = value.act( + {"states": self._shared_state_preprocessor[uid](self._current_shared_next_states.float())}, + role="value", + ) value.train(True) last_values = self._value_preprocessor[uid](last_values, inverse=True) values = memory.get_tensor_by_name("values") - returns, advantages = compute_gae(rewards=memory.get_tensor_by_name("rewards"), - dones=memory.get_tensor_by_name("terminated"), - values=values, - next_values=last_values, - discount_factor=self._discount_factor[uid], - lambda_coefficient=self._lambda[uid]) + returns, advantages = compute_gae( + rewards=memory.get_tensor_by_name("rewards"), + dones=memory.get_tensor_by_name("terminated"), + values=values, + next_values=last_values, + discount_factor=self._discount_factor[uid], + lambda_coefficient=self._lambda[uid], + ) memory.set_tensor_by_name("values", self._value_preprocessor[uid](values, train=True)) memory.set_tensor_by_name("returns", self._value_preprocessor[uid](returns, train=True)) @@ -426,13 +473,22 @@ def compute_gae(rewards: torch.Tensor, kl_divergences = [] # mini-batches loop - for sampled_states, sampled_shared_states, sampled_actions, sampled_log_prob, sampled_values, sampled_returns, sampled_advantages \ - in sampled_batches: + for ( + sampled_states, + sampled_shared_states, + sampled_actions, + sampled_log_prob, + sampled_values, + sampled_returns, + sampled_advantages, + ) in sampled_batches: sampled_states = self._state_preprocessor[uid](sampled_states, train=not epoch) sampled_shared_states = self._shared_state_preprocessor[uid](sampled_shared_states, train=not epoch) - _, 
next_log_prob, _ = policy.act({"states": sampled_states, "taken_actions": sampled_actions}, role="policy") + _, next_log_prob, _ = policy.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="policy" + ) # compute approximate KL divergence with torch.no_grad(): @@ -453,7 +509,9 @@ def compute_gae(rewards: torch.Tensor, # compute policy loss ratio = torch.exp(next_log_prob - sampled_log_prob) surrogate = sampled_advantages * ratio - surrogate_clipped = sampled_advantages * torch.clip(ratio, 1.0 - self._ratio_clip[uid], 1.0 + self._ratio_clip[uid]) + surrogate_clipped = sampled_advantages * torch.clip( + ratio, 1.0 - self._ratio_clip[uid], 1.0 + self._ratio_clip[uid] + ) policy_loss = -torch.min(surrogate, surrogate_clipped).mean() @@ -461,9 +519,9 @@ def compute_gae(rewards: torch.Tensor, predicted_values, _, _ = value.act({"states": sampled_shared_states}, role="value") if self._clip_predicted_values: - predicted_values = sampled_values + torch.clip(predicted_values - sampled_values, - min=-self._value_clip[uid], - max=self._value_clip[uid]) + predicted_values = sampled_values + torch.clip( + predicted_values - sampled_values, min=-self._value_clip[uid], max=self._value_clip[uid] + ) value_loss = self._value_loss_scale[uid] * F.mse_loss(sampled_returns, predicted_values) # optimization step @@ -477,7 +535,9 @@ def compute_gae(rewards: torch.Tensor, if policy is value: nn.utils.clip_grad_norm_(policy.parameters(), self._grad_norm_clip[uid]) else: - nn.utils.clip_grad_norm_(itertools.chain(policy.parameters(), value.parameters()), self._grad_norm_clip[uid]) + nn.utils.clip_grad_norm_( + itertools.chain(policy.parameters(), value.parameters()), self._grad_norm_clip[uid] + ) self.optimizers[uid].step() # update cumulative losses @@ -499,12 +559,23 @@ def compute_gae(rewards: torch.Tensor, self.schedulers[uid].step() # record data - self.track_data(f"Loss / Policy loss ({uid})", cumulative_policy_loss / (self._learning_epochs[uid] * self._mini_batches[uid])) - self.track_data(f"Loss / Value loss ({uid})", cumulative_value_loss / (self._learning_epochs[uid] * self._mini_batches[uid])) + self.track_data( + f"Loss / Policy loss ({uid})", + cumulative_policy_loss / (self._learning_epochs[uid] * self._mini_batches[uid]), + ) + self.track_data( + f"Loss / Value loss ({uid})", + cumulative_value_loss / (self._learning_epochs[uid] * self._mini_batches[uid]), + ) if self._entropy_loss_scale: - self.track_data(f"Loss / Entropy loss ({uid})", cumulative_entropy_loss / (self._learning_epochs[uid] * self._mini_batches[uid])) - - self.track_data(f"Policy / Standard deviation ({uid})", policy.distribution(role="policy").stddev.mean().item()) + self.track_data( + f"Loss / Entropy loss ({uid})", + cumulative_entropy_loss / (self._learning_epochs[uid] * self._mini_batches[uid]), + ) + + self.track_data( + f"Policy / Standard deviation ({uid})", policy.distribution(role="policy").stddev.mean().item() + ) if self._learning_rate_scheduler[uid]: self.track_data(f"Learning / Learning rate ({uid})", self.schedulers[uid].get_last_lr()[0]) diff --git a/skrl/resources/noises/jax/base.py b/skrl/resources/noises/jax/base.py index e51e6fd3..2cd60e0e 100644 --- a/skrl/resources/noises/jax/base.py +++ b/skrl/resources/noises/jax/base.py @@ -6,7 +6,7 @@ from skrl import config -class Noise(): +class Noise: def __init__(self, device: Optional[Union[str, jax.Device]] = None) -> None: """Base class representing a noise @@ -33,7 +33,7 @@ def sample(self, size): else: self.device = device if type(device) == 
str: - device_type, device_index = f"{device}:0".split(':')[:2] + device_type, device_index = f"{device}:0".split(":")[:2] self.device = jax.devices(device_type)[int(device_index)] def sample_like(self, tensor: Union[np.ndarray, jax.Array]) -> Union[np.ndarray, jax.Array]: diff --git a/skrl/resources/noises/jax/ornstein_uhlenbeck.py b/skrl/resources/noises/jax/ornstein_uhlenbeck.py index a9a64679..58d1fca5 100644 --- a/skrl/resources/noises/jax/ornstein_uhlenbeck.py +++ b/skrl/resources/noises/jax/ornstein_uhlenbeck.py @@ -18,13 +18,15 @@ def _sample(theta, sigma, state, mean, std, key, iterator, shape): class OrnsteinUhlenbeckNoise(Noise): - def __init__(self, - theta: float, - sigma: float, - base_scale: float, - mean: float = 0, - std: float = 1, - device: Optional[Union[str, jax.Device]] = None) -> None: + def __init__( + self, + theta: float, + sigma: float, + base_scale: float, + mean: float = 0, + std: float = 1, + device: Optional[Union[str, jax.Device]] = None, + ) -> None: """Class representing an Ornstein-Uhlenbeck noise :param theta: Factor to apply to current internal state diff --git a/skrl/resources/noises/torch/base.py b/skrl/resources/noises/torch/base.py index 90344961..ba83051f 100644 --- a/skrl/resources/noises/torch/base.py +++ b/skrl/resources/noises/torch/base.py @@ -3,7 +3,7 @@ import torch -class Noise(): +class Noise: def __init__(self, device: Optional[Union[str, torch.device]] = None) -> None: """Base class representing a noise diff --git a/skrl/resources/noises/torch/gaussian.py b/skrl/resources/noises/torch/gaussian.py index 6356bdde..cf990d49 100644 --- a/skrl/resources/noises/torch/gaussian.py +++ b/skrl/resources/noises/torch/gaussian.py @@ -24,8 +24,10 @@ def __init__(self, mean: float, std: float, device: Optional[Union[str, torch.de """ super().__init__(device) - self.distribution = Normal(loc=torch.tensor(mean, device=self.device, dtype=torch.float32), - scale=torch.tensor(std, device=self.device, dtype=torch.float32)) + self.distribution = Normal( + loc=torch.tensor(mean, device=self.device, dtype=torch.float32), + scale=torch.tensor(std, device=self.device, dtype=torch.float32), + ) def sample(self, size: Union[Tuple[int], torch.Size]) -> torch.Tensor: """Sample a Gaussian noise diff --git a/skrl/resources/noises/torch/ornstein_uhlenbeck.py b/skrl/resources/noises/torch/ornstein_uhlenbeck.py index cd73d621..e3338a17 100644 --- a/skrl/resources/noises/torch/ornstein_uhlenbeck.py +++ b/skrl/resources/noises/torch/ornstein_uhlenbeck.py @@ -7,13 +7,15 @@ class OrnsteinUhlenbeckNoise(Noise): - def __init__(self, - theta: float, - sigma: float, - base_scale: float, - mean: float = 0, - std: float = 1, - device: Optional[Union[str, torch.device]] = None) -> None: + def __init__( + self, + theta: float, + sigma: float, + base_scale: float, + mean: float = 0, + std: float = 1, + device: Optional[Union[str, torch.device]] = None, + ) -> None: """Class representing an Ornstein-Uhlenbeck noise :param theta: Factor to apply to current internal state @@ -41,8 +43,10 @@ def __init__(self, self.sigma = sigma self.base_scale = base_scale - self.distribution = Normal(loc=torch.tensor(mean, device=self.device, dtype=torch.float32), - scale=torch.tensor(std, device=self.device, dtype=torch.float32)) + self.distribution = Normal( + loc=torch.tensor(mean, device=self.device, dtype=torch.float32), + scale=torch.tensor(std, device=self.device, dtype=torch.float32), + ) def sample(self, size: Union[Tuple[int], torch.Size]) -> torch.Tensor: """Sample an Ornstein-Uhlenbeck 
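Only the constructor of OrnsteinUhlenbeckNoise is reformatted here (including the Normal(loc=mean, scale=std) distribution it builds), so as background the textbook discrete-time Ornstein-Uhlenbeck recursion is sketched below with the same parameter names. This is the generic update rule, not skrl's sample() implementation, and ou_step is an illustrative name; how sample() combines theta, sigma and base_scale is outside this patch.

import numpy as np

def ou_step(state, theta, sigma, mean=0.0, std=1.0, rng=np.random.default_rng(0)):
    # mean-reverting drift plus a Gaussian perturbation
    return state + theta * (mean - state) + sigma * rng.normal(mean, std, size=np.shape(state))

state = np.zeros(3)
for _ in range(5):
    state = ou_step(state, theta=0.15, sigma=0.2)
noise = 1.0 * state  # scaled by base_scale (1.0 in this toy run)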
noise diff --git a/skrl/resources/optimizers/jax/adam.py b/skrl/resources/optimizers/jax/adam.py index 767a364f..f8b405d1 100644 --- a/skrl/resources/optimizers/jax/adam.py +++ b/skrl/resources/optimizers/jax/adam.py @@ -63,6 +63,7 @@ def __new__(cls, model: Model, lr: float = 1e-3, grad_norm_clip: float = 0, scal >>> # step the optimizer given a computed gradiend and an updated learning rate (lr) >>> optimizer = optimizer.step(grad, policy, lr) """ + class Optimizer(flax.struct.PyTreeNode): """Optimizer @@ -71,6 +72,7 @@ class Optimizer(flax.struct.PyTreeNode): https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#train-state """ + transformation: optax.GradientTransformation = flax.struct.field(pytree_node=False) state: optax.OptState = flax.struct.field(pytree_node=True) @@ -95,7 +97,9 @@ def step(self, grad: jax.Array, model: Model, lr: Optional[float] = None) -> "Op if lr is None: optimizer_state, model.state_dict = _step(self.transformation, grad, self.state, model.state_dict) else: - optimizer_state, model.state_dict = _step_with_scale(self.transformation, grad, self.state, model.state_dict, -lr) + optimizer_state, model.state_dict = _step_with_scale( + self.transformation, grad, self.state, model.state_dict, -lr + ) return self.replace(state=optimizer_state) # default optax transformation diff --git a/skrl/resources/preprocessors/jax/running_standard_scaler.py b/skrl/resources/preprocessors/jax/running_standard_scaler.py index 3563942e..1de08152 100644 --- a/skrl/resources/preprocessors/jax/running_standard_scaler.py +++ b/skrl/resources/preprocessors/jax/running_standard_scaler.py @@ -13,16 +13,14 @@ # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function @jax.jit def _copyto(dst, src): - """NumPy function copyto not yet implemented - """ + """NumPy function copyto not yet implemented""" return dst.at[:].set(src) @jax.jit -def _parallel_variance(running_mean: jax.Array, - running_variance: jax.Array, - current_count: jax.Array, - array: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]: # yapf: disable +def _parallel_variance( + running_mean: jax.Array, running_variance: jax.Array, current_count: jax.Array, array: jax.Array +) -> Tuple[jax.Array, jax.Array, jax.Array]: # yapf: disable # ddof = 1: https://github.com/pytorch/pytorch/issues/50010 if array.ndim == 3: input_mean = jnp.mean(array, axis=(0, 1)) @@ -35,35 +33,37 @@ def _parallel_variance(running_mean: jax.Array, delta = input_mean - running_mean total_count = current_count + input_count - M2 = (running_variance * current_count) + (input_var * input_count) \ - + delta ** 2 * current_count * input_count / total_count + M2 = ( + (running_variance * current_count) + + (input_var * input_count) + + delta**2 * current_count * input_count / total_count + ) return running_mean + delta * input_count / total_count, M2 / total_count, total_count @jax.jit -def _inverse(running_mean: jax.Array, - running_variance: jax.Array, - clip_threshold: float, - array: jax.Array) -> jax.Array: # yapf: disable +def _inverse( + running_mean: jax.Array, running_variance: jax.Array, clip_threshold: float, array: jax.Array +) -> jax.Array: # yapf: disable return jnp.sqrt(running_variance) * jnp.clip(array, -clip_threshold, clip_threshold) + running_mean @jax.jit -def _standardization(running_mean: jax.Array, - running_variance: jax.Array, - clip_threshold: float, - epsilon: float, - array: jax.Array) -> jax.Array: +def _standardization( + running_mean: jax.Array, running_variance: jax.Array, 
clip_threshold: float, epsilon: float, array: jax.Array +) -> jax.Array: return jnp.clip((array - running_mean) / (jnp.sqrt(running_variance) + epsilon), -clip_threshold, clip_threshold) class RunningStandardScaler: - def __init__(self, - size: Union[int, Tuple[int], gymnasium.Space], - epsilon: float = 1e-8, - clip_threshold: float = 5.0, - device: Optional[Union[str, jax.Device]] = None) -> None: + def __init__( + self, + size: Union[int, Tuple[int], gymnasium.Space], + epsilon: float = 1e-8, + clip_threshold: float = 5.0, + device: Optional[Union[str, jax.Device]] = None, + ) -> None: """Standardize the input data by removing the mean and scaling by the standard deviation The implementation is adapted from the rl_games library @@ -97,7 +97,7 @@ def __init__(self, else: self.device = device if type(device) == str: - device_type, device_index = f"{device}:0".split(':')[:2] + device_type, device_index = f"{device}:0".split(":")[:2] self.device = jax.devices(device_type)[int(device_index)] size = compute_space_size(size, occupied_size=True) @@ -114,8 +114,8 @@ def __init__(self, @property def state_dict(self) -> Mapping[str, Union[np.ndarray, jax.Array]]: - """Dictionary containing references to the whole state of the module - """ + """Dictionary containing references to the whole state of the module""" + class _StateDict: def __init__(self, params): self.params = params @@ -123,11 +123,13 @@ def __init__(self, params): def replace(self, params): return params - return _StateDict({ - "running_mean": self.running_mean, - "running_variance": self.running_variance, - "current_count": self.current_count - }) + return _StateDict( + { + "running_mean": self.running_mean, + "running_variance": self.running_variance, + "current_count": self.current_count, + } + ) @state_dict.setter def state_dict(self, value: Mapping[str, Union[np.ndarray, jax.Array]]) -> None: @@ -140,10 +142,9 @@ def state_dict(self, value: Mapping[str, Union[np.ndarray, jax.Array]]) -> None: np.copyto(self.running_variance, value["running_variance"]) np.copyto(self.current_count, value["current_count"]) - def _parallel_variance(self, - input_mean: Union[np.ndarray, jax.Array], - input_var: Union[np.ndarray, jax.Array], - input_count: int) -> None: + def _parallel_variance( + self, input_mean: Union[np.ndarray, jax.Array], input_var: Union[np.ndarray, jax.Array], input_count: int + ) -> None: """Update internal variables using the parallel algorithm for computing variance https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm @@ -157,18 +158,20 @@ def _parallel_variance(self, """ delta = input_mean - self.running_mean total_count = self.current_count + input_count - M2 = (self.running_variance * self.current_count) + (input_var * input_count) \ - + delta ** 2 * self.current_count * input_count / total_count + M2 = ( + (self.running_variance * self.current_count) + + (input_var * input_count) + + delta**2 * self.current_count * input_count / total_count + ) # update internal variables self.running_mean = self.running_mean + delta * input_count / total_count self.running_variance = M2 / total_count self.current_count = total_count - def __call__(self, - x: Union[np.ndarray, jax.Array], - train: bool = False, - inverse: bool = False) -> Union[np.ndarray, jax.Array]: + def __call__( + self, x: Union[np.ndarray, jax.Array], train: bool = False, inverse: bool = False + ) -> Union[np.ndarray, jax.Array]: """Forward pass of the standardizer Example:: @@ -201,14 +204,15 @@ def __call__(self, """ if train: if 
self._jax: - self.running_mean, self.running_variance, self.current_count = \ - _parallel_variance(self.running_mean, self.running_variance, self.current_count, x) + self.running_mean, self.running_variance, self.current_count = _parallel_variance( + self.running_mean, self.running_variance, self.current_count, x + ) else: # ddof = 1: https://github.com/pytorch/pytorch/issues/50010 if x.ndim == 3: - self._parallel_variance(np.mean(x, axis=(0, 1)), - np.var(x, axis=(0, 1), ddof=1), - x.shape[0] * x.shape[1]) + self._parallel_variance( + np.mean(x, axis=(0, 1)), np.var(x, axis=(0, 1), ddof=1), x.shape[0] * x.shape[1] + ) else: self._parallel_variance(np.mean(x, axis=0), np.var(x, axis=0, ddof=1), x.shape[0]) @@ -216,11 +220,15 @@ def __call__(self, if inverse: if self._jax: return _inverse(self.running_mean, self.running_variance, self.clip_threshold, x) - return np.sqrt(self.running_variance) * np.clip(x, -self.clip_threshold, - self.clip_threshold) + self.running_mean + return ( + np.sqrt(self.running_variance) * np.clip(x, -self.clip_threshold, self.clip_threshold) + + self.running_mean + ) # standardization by centering and scaling if self._jax: return _standardization(self.running_mean, self.running_variance, self.clip_threshold, self.epsilon, x) - return np.clip((x - self.running_mean) / (np.sqrt(self.running_variance) + self.epsilon), - a_min=-self.clip_threshold, - a_max=self.clip_threshold) + return np.clip( + (x - self.running_mean) / (np.sqrt(self.running_variance) + self.epsilon), + a_min=-self.clip_threshold, + a_max=self.clip_threshold, + ) diff --git a/skrl/resources/preprocessors/torch/running_standard_scaler.py b/skrl/resources/preprocessors/torch/running_standard_scaler.py index 43f5cfda..17f066f8 100644 --- a/skrl/resources/preprocessors/torch/running_standard_scaler.py +++ b/skrl/resources/preprocessors/torch/running_standard_scaler.py @@ -9,11 +9,13 @@ class RunningStandardScaler(nn.Module): - def __init__(self, - size: Union[int, Tuple[int], gymnasium.Space], - epsilon: float = 1e-8, - clip_threshold: float = 5.0, - device: Optional[Union[str, torch.device]] = None) -> None: + def __init__( + self, + size: Union[int, Tuple[int], gymnasium.Space], + epsilon: float = 1e-8, + clip_threshold: float = 5.0, + device: Optional[Union[str, torch.device]] = None, + ) -> None: """Standardize the input data by removing the mean and scaling by the standard deviation The implementation is adapted from the rl_games library @@ -67,8 +69,11 @@ def _parallel_variance(self, input_mean: torch.Tensor, input_var: torch.Tensor, """ delta = input_mean - self.running_mean total_count = self.current_count + input_count - M2 = (self.running_variance * self.current_count) + (input_var * input_count) \ - + delta ** 2 * self.current_count * input_count / total_count + M2 = ( + (self.running_variance * self.current_count) + + (input_var * input_count) + + delta**2 * self.current_count * input_count / total_count + ) # update internal variables self.running_mean = self.running_mean + delta * input_count / total_count @@ -96,18 +101,21 @@ def _compute(self, x: torch.Tensor, train: bool = False, inverse: bool = False) # scale back the data to the original representation if inverse: - return torch.sqrt(self.running_variance.float()) \ - * torch.clamp(x, min=-self.clip_threshold, max=self.clip_threshold) + self.running_mean.float() + return ( + torch.sqrt(self.running_variance.float()) + * torch.clamp(x, min=-self.clip_threshold, max=self.clip_threshold) + + self.running_mean.float() + ) # standardization 
by centering and scaling - return torch.clamp((x - self.running_mean.float()) / (torch.sqrt(self.running_variance.float()) + self.epsilon), - min=-self.clip_threshold, - max=self.clip_threshold) - - def forward(self, - x: torch.Tensor, - train: bool = False, - inverse: bool = False, - no_grad: bool = True) -> torch.Tensor: + return torch.clamp( + (x - self.running_mean.float()) / (torch.sqrt(self.running_variance.float()) + self.epsilon), + min=-self.clip_threshold, + max=self.clip_threshold, + ) + + def forward( + self, x: torch.Tensor, train: bool = False, inverse: bool = False, no_grad: bool = True + ) -> torch.Tensor: """Forward pass of the standardizer Example:: diff --git a/skrl/resources/schedulers/jax/kl_adaptive.py b/skrl/resources/schedulers/jax/kl_adaptive.py index dbf2c4c6..149bc922 100644 --- a/skrl/resources/schedulers/jax/kl_adaptive.py +++ b/skrl/resources/schedulers/jax/kl_adaptive.py @@ -4,13 +4,15 @@ class KLAdaptiveLR: - def __init__(self, - init_value: float, - kl_threshold: float = 0.008, - min_lr: float = 1e-6, - max_lr: float = 1e-2, - kl_factor: float = 2, - lr_factor: float = 1.5) -> None: + def __init__( + self, + init_value: float, + kl_threshold: float = 0.008, + min_lr: float = 1e-6, + max_lr: float = 1e-2, + kl_factor: float = 2, + lr_factor: float = 1.5, + ) -> None: """Adaptive KL scheduler Adjusts the learning rate according to the KL divergence. @@ -54,8 +56,7 @@ def __init__(self, @property def lr(self) -> float: - """Learning rate - """ + """Learning rate""" return self._lr def step(self, kl: Optional[Union[np.ndarray, float]] = None) -> None: diff --git a/skrl/resources/schedulers/torch/kl_adaptive.py b/skrl/resources/schedulers/torch/kl_adaptive.py index 0a32eed3..c63f1563 100644 --- a/skrl/resources/schedulers/torch/kl_adaptive.py +++ b/skrl/resources/schedulers/torch/kl_adaptive.py @@ -7,15 +7,17 @@ class KLAdaptiveLR(_LRScheduler): - def __init__(self, - optimizer: torch.optim.Optimizer, - kl_threshold: float = 0.008, - min_lr: float = 1e-6, - max_lr: float = 1e-2, - kl_factor: float = 2, - lr_factor: float = 1.5, - last_epoch: int = -1, - verbose: bool = False) -> None: + def __init__( + self, + optimizer: torch.optim.Optimizer, + kl_threshold: float = 0.008, + min_lr: float = 1e-6, + max_lr: float = 1e-2, + kl_factor: float = 2, + lr_factor: float = 1.5, + last_epoch: int = -1, + verbose: bool = False, + ) -> None: """Adaptive KL scheduler Adjusts the learning rate according to the KL divergence. 
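For reference, the adaptation rule that these KLAdaptiveLR hunks merely reformat can be summarized in a few lines. The following is an illustrative, standalone sketch (plain Python with a hypothetical function name, not the scheduler classes themselves); the default values are the ones shown in the constructors in this patch.

def kl_adaptive_lr(lr, kl, kl_threshold=0.008, min_lr=1e-6, max_lr=1e-2, kl_factor=2, lr_factor=1.5):
    """Return the learning rate adjusted for the measured KL divergence (illustration only)."""
    if kl > kl_threshold * kl_factor:    # policy changed too much: decrease the learning rate
        return max(lr / lr_factor, min_lr)
    if kl < kl_threshold / kl_factor:    # policy changed too little: increase the learning rate
        return min(lr * lr_factor, max_lr)
    return lr                            # within the tolerance band: keep the learning rate

The torch scheduler applies this rule to each optimizer parameter group and the JAX scheduler to a single stored value, but the decision logic shown in the hunks below is the same.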
@@ -62,7 +64,7 @@ def __init__(self, self._kl_factor = kl_factor self._lr_factor = lr_factor - self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] def step(self, kl: Optional[Union[torch.Tensor, float]] = None, epoch: Optional[int] = None) -> None: """ @@ -88,8 +90,8 @@ def step(self, kl: Optional[Union[torch.Tensor, float]] = None, epoch: Optional[ if kl is not None: for group in self.optimizer.param_groups: if kl > self.kl_threshold * self._kl_factor: - group['lr'] = max(group['lr'] / self._lr_factor, self.min_lr) + group["lr"] = max(group["lr"] / self._lr_factor, self.min_lr) elif kl < self.kl_threshold / self._kl_factor: - group['lr'] = min(group['lr'] * self._lr_factor, self.max_lr) + group["lr"] = min(group["lr"] * self._lr_factor, self.max_lr) - self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] diff --git a/skrl/trainers/jax/base.py b/skrl/trainers/jax/base.py index a9205e56..2e086820 100644 --- a/skrl/trainers/jax/base.py +++ b/skrl/trainers/jax/base.py @@ -27,16 +27,20 @@ def generate_equally_spaced_scopes(num_envs: int, num_simultaneous_agents: int) if sum(scopes): scopes[-1] += num_envs - sum(scopes) else: - raise ValueError(f"The number of simultaneous agents ({num_simultaneous_agents}) is greater than the number of environments ({num_envs})") + raise ValueError( + f"The number of simultaneous agents ({num_simultaneous_agents}) is greater than the number of environments ({num_envs})" + ) return scopes class Trainer: - def __init__(self, - env: Wrapper, - agents: Union[Agent, List[Agent]], - agents_scope: Optional[List[int]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + env: Wrapper, + agents: Union[Agent, List[Agent]], + agents_scope: Optional[List[int]] = None, + cfg: Optional[dict] = None, + ) -> None: """Base class for trainers :param env: Environment to train on @@ -68,6 +72,7 @@ def __init__(self, # register environment closing if configured if self.close_environment_at_exit: + @atexit.register def close_env(): logger.info("Closing environment") @@ -120,11 +125,17 @@ def _setup_agents(self) -> None: if sum(self.agents_scope): self.agents_scope[-1] += self.env.num_envs - sum(self.agents_scope) else: - raise ValueError(f"The number of agents ({len(self.agents)}) is greater than the number of parallelizable environments ({self.env.num_envs})") + raise ValueError( + f"The number of agents ({len(self.agents)}) is greater than the number of parallelizable environments ({self.env.num_envs})" + ) elif len(self.agents_scope) != len(self.agents): - raise ValueError(f"The number of agents ({len(self.agents)}) doesn't match the number of scopes ({len(self.agents_scope)})") + raise ValueError( + f"The number of agents ({len(self.agents)}) doesn't match the number of scopes ({len(self.agents_scope)})" + ) elif sum(self.agents_scope) != self.env.num_envs: - raise ValueError(f"The scopes ({sum(self.agents_scope)}) don't cover the number of parallelizable environments ({self.env.num_envs})") + raise ValueError( + f"The scopes ({sum(self.agents_scope)}) don't cover the number of parallelizable environments ({self.env.num_envs})" + ) # generate agents' scopes index = 0 for i in range(len(self.agents_scope)): @@ -168,7 +179,9 @@ def single_agent_train(self) -> None: # reset env states, infos = self.env.reset() - for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), 
disable=self.disable_progressbar, file=sys.stdout): + for timestep in tqdm.tqdm( + range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout + ): # pre-interaction self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) @@ -185,15 +198,17 @@ def single_agent_train(self) -> None: self.env.render() # record the environments' transitions - self.agents.record_transition(states=states, - actions=actions, - rewards=rewards, - next_states=next_states, - terminated=terminated, - truncated=truncated, - infos=infos, - timestep=timestep, - timesteps=self.timesteps) + self.agents.record_transition( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + infos=infos, + timestep=timestep, + timesteps=self.timesteps, + ) # post-interaction self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps) @@ -224,7 +239,9 @@ def single_agent_eval(self) -> None: # reset env states, infos = self.env.reset() - for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): + for timestep in tqdm.tqdm( + range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout + ): # pre-interaction self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) @@ -241,15 +258,17 @@ def single_agent_eval(self) -> None: self.env.render() # write data to TensorBoard - self.agents.record_transition(states=states, - actions=actions, - rewards=rewards, - next_states=next_states, - terminated=terminated, - truncated=truncated, - infos=infos, - timestep=timestep, - timesteps=self.timesteps) + self.agents.record_transition( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + infos=infos, + timestep=timestep, + timesteps=self.timesteps, + ) # post-interaction super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps) @@ -284,7 +303,9 @@ def multi_agent_train(self) -> None: states, infos = self.env.reset() shared_states = self.env.state() - for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): + for timestep in tqdm.tqdm( + range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout + ): # pre-interaction self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) @@ -304,15 +325,17 @@ def multi_agent_train(self) -> None: self.env.render() # record the environments' transitions - self.agents.record_transition(states=states, - actions=actions, - rewards=rewards, - next_states=next_states, - terminated=terminated, - truncated=truncated, - infos=infos, - timestep=timestep, - timesteps=self.timesteps) + self.agents.record_transition( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + infos=infos, + timestep=timestep, + timesteps=self.timesteps, + ) # post-interaction self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps) @@ -343,7 +366,9 @@ def multi_agent_eval(self) -> None: states, infos = self.env.reset() shared_states = self.env.state() - for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): + for timestep in tqdm.tqdm( + range(self.initial_timestep, self.timesteps), 
disable=self.disable_progressbar, file=sys.stdout + ): # pre-interaction self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) @@ -363,15 +388,17 @@ def multi_agent_eval(self) -> None: self.env.render() # write data to TensorBoard - self.agents.record_transition(states=states, - actions=actions, - rewards=rewards, - next_states=next_states, - terminated=terminated, - truncated=truncated, - infos=infos, - timestep=timestep, - timesteps=self.timesteps) + self.agents.record_transition( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + infos=infos, + timestep=timestep, + timesteps=self.timesteps, + ) # post-interaction super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps) diff --git a/skrl/trainers/jax/sequential.py b/skrl/trainers/jax/sequential.py index d35c3179..5e690d91 100644 --- a/skrl/trainers/jax/sequential.py +++ b/skrl/trainers/jax/sequential.py @@ -26,11 +26,13 @@ class SequentialTrainer(Trainer): - def __init__(self, - env: Wrapper, - agents: Union[Agent, List[Agent]], - agents_scope: Optional[List[int]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + env: Wrapper, + agents: Union[Agent, List[Agent]], + agents_scope: Optional[List[int]] = None, + cfg: Optional[dict] = None, + ) -> None: """Sequential trainer Train agents sequentially (i.e., one after the other in each interaction with the environment) @@ -90,7 +92,9 @@ def train(self) -> None: # reset env states, infos = self.env.reset() - for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): + for timestep in tqdm.tqdm( + range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout + ): # pre-interaction for agent in self.agents: @@ -98,8 +102,12 @@ def train(self) -> None: with contextlib.nullcontext(): # compute actions - actions = jnp.vstack([agent.act(states[scope[0]:scope[1]], timestep=timestep, timesteps=self.timesteps)[0] \ - for agent, scope in zip(self.agents, self.agents_scope)]) + actions = jnp.vstack( + [ + agent.act(states[scope[0] : scope[1]], timestep=timestep, timesteps=self.timesteps)[0] + for agent, scope in zip(self.agents, self.agents_scope) + ] + ) # step the environments next_states, rewards, terminated, truncated, infos = self.env.step(actions) @@ -110,15 +118,17 @@ def train(self) -> None: # record the environments' transitions for agent, scope in zip(self.agents, self.agents_scope): - agent.record_transition(states=states[scope[0]:scope[1]], - actions=actions[scope[0]:scope[1]], - rewards=rewards[scope[0]:scope[1]], - next_states=next_states[scope[0]:scope[1]], - terminated=terminated[scope[0]:scope[1]], - truncated=truncated[scope[0]:scope[1]], - infos=infos, - timestep=timestep, - timesteps=self.timesteps) + agent.record_transition( + states=states[scope[0] : scope[1]], + actions=actions[scope[0] : scope[1]], + rewards=rewards[scope[0] : scope[1]], + next_states=next_states[scope[0] : scope[1]], + terminated=terminated[scope[0] : scope[1]], + truncated=truncated[scope[0] : scope[1]], + infos=infos, + timestep=timestep, + timesteps=self.timesteps, + ) # post-interaction for agent in self.agents: @@ -161,7 +171,9 @@ def eval(self) -> None: # reset env states, infos = self.env.reset() - for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): + for timestep in 
tqdm.tqdm( + range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout + ): # pre-interaction for agent in self.agents: @@ -169,8 +181,12 @@ def eval(self) -> None: with contextlib.nullcontext(): # compute actions - actions = jnp.vstack([agent.act(states[scope[0]:scope[1]], timestep=timestep, timesteps=self.timesteps)[0] \ - for agent, scope in zip(self.agents, self.agents_scope)]) + actions = jnp.vstack( + [ + agent.act(states[scope[0] : scope[1]], timestep=timestep, timesteps=self.timesteps)[0] + for agent, scope in zip(self.agents, self.agents_scope) + ] + ) # step the environments next_states, rewards, terminated, truncated, infos = self.env.step(actions) @@ -181,15 +197,17 @@ def eval(self) -> None: # write data to TensorBoard for agent, scope in zip(self.agents, self.agents_scope): - agent.record_transition(states=states[scope[0]:scope[1]], - actions=actions[scope[0]:scope[1]], - rewards=rewards[scope[0]:scope[1]], - next_states=next_states[scope[0]:scope[1]], - terminated=terminated[scope[0]:scope[1]], - truncated=truncated[scope[0]:scope[1]], - infos=infos, - timestep=timestep, - timesteps=self.timesteps) + agent.record_transition( + states=states[scope[0] : scope[1]], + actions=actions[scope[0] : scope[1]], + rewards=rewards[scope[0] : scope[1]], + next_states=next_states[scope[0] : scope[1]], + terminated=terminated[scope[0] : scope[1]], + truncated=truncated[scope[0] : scope[1]], + infos=infos, + timestep=timestep, + timesteps=self.timesteps, + ) # post-interaction for agent in self.agents: diff --git a/skrl/trainers/jax/step.py b/skrl/trainers/jax/step.py index e164ccde..0bb3b361 100644 --- a/skrl/trainers/jax/step.py +++ b/skrl/trainers/jax/step.py @@ -28,11 +28,13 @@ class StepTrainer(Trainer): - def __init__(self, - env: Wrapper, - agents: Union[Agent, List[Agent]], - agents_scope: Optional[List[int]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + env: Wrapper, + agents: Union[Agent, List[Agent]], + agents_scope: Optional[List[int]] = None, + cfg: Optional[dict] = None, + ) -> None: """Step-by-step trainer Train agents by controlling the training/evaluation loop step by step @@ -64,9 +66,13 @@ def __init__(self, self.states = None - def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> \ - Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], - Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Any]: + def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> Tuple[ + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Any, + ]: """Execute a training iteration This method executes the following steps once: @@ -116,8 +122,12 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) with contextlib.nullcontext(): # compute actions - actions = jnp.vstack([agent.act(self.states[scope[0]:scope[1]], timestep=timestep, timesteps=timesteps)[0] \ - for agent, scope in zip(self.agents, self.agents_scope)]) + actions = jnp.vstack( + [ + agent.act(self.states[scope[0] : scope[1]], timestep=timestep, timesteps=timesteps)[0] + for agent, scope in zip(self.agents, self.agents_scope) + ] + ) # step the environments next_states, rewards, terminated, truncated, infos = self.env.step(actions) @@ -128,15 +138,17 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) # record the environments' transitions for agent, scope in 
zip(self.agents, self.agents_scope): - agent.record_transition(states=self.states[scope[0]:scope[1]], - actions=actions[scope[0]:scope[1]], - rewards=rewards[scope[0]:scope[1]], - next_states=next_states[scope[0]:scope[1]], - terminated=terminated[scope[0]:scope[1]], - truncated=truncated[scope[0]:scope[1]], - infos=infos, - timestep=timestep, - timesteps=timesteps) + agent.record_transition( + states=self.states[scope[0] : scope[1]], + actions=actions[scope[0] : scope[1]], + rewards=rewards[scope[0] : scope[1]], + next_states=next_states[scope[0] : scope[1]], + terminated=terminated[scope[0] : scope[1]], + truncated=truncated[scope[0] : scope[1]], + infos=infos, + timestep=timestep, + timesteps=timesteps, + ) # post-interaction for agent in self.agents: @@ -151,9 +163,13 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) return next_states, rewards, terminated, truncated, infos - def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> \ - Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], - Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Any]: + def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> Tuple[ + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Union[np.ndarray, jax.Array], + Any, + ]: """Evaluate the agents sequentially This method executes the following steps in loop: @@ -200,8 +216,12 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) with contextlib.nullcontext(): # compute actions - actions = jnp.vstack([agent.act(self.states[scope[0]:scope[1]], timestep=timestep, timesteps=timesteps)[0] \ - for agent, scope in zip(self.agents, self.agents_scope)]) + actions = jnp.vstack( + [ + agent.act(self.states[scope[0] : scope[1]], timestep=timestep, timesteps=timesteps)[0] + for agent, scope in zip(self.agents, self.agents_scope) + ] + ) # step the environments next_states, rewards, terminated, truncated, infos = self.env.step(actions) @@ -212,15 +232,17 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) # write data to TensorBoard for agent, scope in zip(self.agents, self.agents_scope): - agent.record_transition(states=self.states[scope[0]:scope[1]], - actions=actions[scope[0]:scope[1]], - rewards=rewards[scope[0]:scope[1]], - next_states=next_states[scope[0]:scope[1]], - terminated=terminated[scope[0]:scope[1]], - truncated=truncated[scope[0]:scope[1]], - infos=infos, - timestep=timestep, - timesteps=timesteps) + agent.record_transition( + states=self.states[scope[0] : scope[1]], + actions=actions[scope[0] : scope[1]], + rewards=rewards[scope[0] : scope[1]], + next_states=next_states[scope[0] : scope[1]], + terminated=terminated[scope[0] : scope[1]], + truncated=truncated[scope[0] : scope[1]], + infos=infos, + timestep=timestep, + timesteps=timesteps, + ) # post-interaction for agent in self.agents: diff --git a/skrl/trainers/torch/base.py b/skrl/trainers/torch/base.py index 3d3c607e..16b61161 100644 --- a/skrl/trainers/torch/base.py +++ b/skrl/trainers/torch/base.py @@ -28,16 +28,20 @@ def generate_equally_spaced_scopes(num_envs: int, num_simultaneous_agents: int) if sum(scopes): scopes[-1] += num_envs - sum(scopes) else: - raise ValueError(f"The number of simultaneous agents ({num_simultaneous_agents}) is greater than the number of environments ({num_envs})") + raise ValueError( + f"The number of simultaneous agents ({num_simultaneous_agents}) is greater 
than the number of environments ({num_envs})" + ) return scopes class Trainer: - def __init__(self, - env: Wrapper, - agents: Union[Agent, List[Agent]], - agents_scope: Optional[List[int]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + env: Wrapper, + agents: Union[Agent, List[Agent]], + agents_scope: Optional[List[int]] = None, + cfg: Optional[dict] = None, + ) -> None: """Base class for trainers :param env: Environment to train on @@ -69,6 +73,7 @@ def __init__(self, # register environment closing if configured if self.close_environment_at_exit: + @atexit.register def close_env(): logger.info("Closing environment") @@ -121,11 +126,17 @@ def _setup_agents(self) -> None: if sum(self.agents_scope): self.agents_scope[-1] += self.env.num_envs - sum(self.agents_scope) else: - raise ValueError(f"The number of agents ({len(self.agents)}) is greater than the number of parallelizable environments ({self.env.num_envs})") + raise ValueError( + f"The number of agents ({len(self.agents)}) is greater than the number of parallelizable environments ({self.env.num_envs})" + ) elif len(self.agents_scope) != len(self.agents): - raise ValueError(f"The number of agents ({len(self.agents)}) doesn't match the number of scopes ({len(self.agents_scope)})") + raise ValueError( + f"The number of agents ({len(self.agents)}) doesn't match the number of scopes ({len(self.agents_scope)})" + ) elif sum(self.agents_scope) != self.env.num_envs: - raise ValueError(f"The scopes ({sum(self.agents_scope)}) don't cover the number of parallelizable environments ({self.env.num_envs})") + raise ValueError( + f"The scopes ({sum(self.agents_scope)}) don't cover the number of parallelizable environments ({self.env.num_envs})" + ) # generate agents' scopes index = 0 for i in range(len(self.agents_scope)): @@ -169,7 +180,9 @@ def single_agent_train(self) -> None: # reset env states, infos = self.env.reset() - for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): + for timestep in tqdm.tqdm( + range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout + ): # pre-interaction self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) @@ -186,15 +199,17 @@ def single_agent_train(self) -> None: self.env.render() # record the environments' transitions - self.agents.record_transition(states=states, - actions=actions, - rewards=rewards, - next_states=next_states, - terminated=terminated, - truncated=truncated, - infos=infos, - timestep=timestep, - timesteps=self.timesteps) + self.agents.record_transition( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + infos=infos, + timestep=timestep, + timesteps=self.timesteps, + ) # log environment info if self.environment_info in infos: @@ -231,7 +246,9 @@ def single_agent_eval(self) -> None: # reset env states, infos = self.env.reset() - for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): + for timestep in tqdm.tqdm( + range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout + ): # pre-interaction self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) @@ -248,15 +265,17 @@ def single_agent_eval(self) -> None: self.env.render() # write data to TensorBoard - self.agents.record_transition(states=states, - actions=actions, - rewards=rewards, - 
next_states=next_states, - terminated=terminated, - truncated=truncated, - infos=infos, - timestep=timestep, - timesteps=self.timesteps) + self.agents.record_transition( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + infos=infos, + timestep=timestep, + timesteps=self.timesteps, + ) # log environment info if self.environment_info in infos: @@ -297,7 +316,9 @@ def multi_agent_train(self) -> None: states, infos = self.env.reset() shared_states = self.env.state() - for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): + for timestep in tqdm.tqdm( + range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout + ): # pre-interaction self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) @@ -317,15 +338,17 @@ def multi_agent_train(self) -> None: self.env.render() # record the environments' transitions - self.agents.record_transition(states=states, - actions=actions, - rewards=rewards, - next_states=next_states, - terminated=terminated, - truncated=truncated, - infos=infos, - timestep=timestep, - timesteps=self.timesteps) + self.agents.record_transition( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + infos=infos, + timestep=timestep, + timesteps=self.timesteps, + ) # log environment info if self.environment_info in infos: @@ -362,7 +385,9 @@ def multi_agent_eval(self) -> None: states, infos = self.env.reset() shared_states = self.env.state() - for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): + for timestep in tqdm.tqdm( + range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout + ): # pre-interaction self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) @@ -382,15 +407,17 @@ def multi_agent_eval(self) -> None: self.env.render() # write data to TensorBoard - self.agents.record_transition(states=states, - actions=actions, - rewards=rewards, - next_states=next_states, - terminated=terminated, - truncated=truncated, - infos=infos, - timestep=timestep, - timesteps=self.timesteps) + self.agents.record_transition( + states=states, + actions=actions, + rewards=rewards, + next_states=next_states, + terminated=terminated, + truncated=truncated, + infos=infos, + timestep=timestep, + timesteps=self.timesteps, + ) # log environment info if self.environment_info in infos: diff --git a/skrl/trainers/torch/parallel.py b/skrl/trainers/torch/parallel.py index a31b3a9b..b409ae7f 100644 --- a/skrl/trainers/torch/parallel.py +++ b/skrl/trainers/torch/parallel.py @@ -43,14 +43,14 @@ def fn_processor(process_index, *args): while True: msg = pipe.recv() - task = msg['task'] + task = msg["task"] # terminate process - if task == 'terminate': + if task == "terminate": break # initialize agent - elif task == 'init': + elif task == "init": agent = queue.get() agent.init(trainer_cfg=trainer_cfg) print(f"[INFO] Processor {process_index}: init agent {type(agent).__name__} with scope {scope}") @@ -58,14 +58,14 @@ def fn_processor(process_index, *args): # execute agent's pre-interaction step elif task == "pre_interaction": - agent.pre_interaction(timestep=msg['timestep'], timesteps=msg['timesteps']) + agent.pre_interaction(timestep=msg["timestep"], timesteps=msg["timesteps"]) barrier.wait() # get agent's actions 
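The parallel trainer code being reformatted in the next hunks implements a small message-passing protocol: the trainer process sends a dict with a "task" key over a per-worker pipe, exchanges tensors through a shared queue, and synchronizes on a barrier, while each worker loops on pipe.recv() and dispatches on the task name. A stripped-down sketch of that worker loop follows (illustration only, with simplified handling and a hypothetical function name; the real fn_processor covers more tasks, as the hunks below show).

def worker_loop(pipe, queue, barrier):
    """Minimal dispatch loop mirroring the shape of fn_processor (not the actual implementation)."""
    agent = None
    while True:
        msg = pipe.recv()                  # block until the trainer sends the next task message
        if msg["task"] == "terminate":
            break
        if msg["task"] == "init":
            agent = queue.get()            # the agent object is handed over once at start-up
        elif msg["task"] == "act":
            states = queue.get()           # shared states put on the queue by the trainer
            actions = agent.act(states, timestep=msg["timestep"], timesteps=msg["timesteps"])[0]
            queue.put(actions)             # hand the computed actions back through the same queue
        barrier.wait()                     # keep trainer and workers in lock-step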
elif task == "act": - _states = queue.get()[scope[0]:scope[1]] + _states = queue.get()[scope[0] : scope[1]] with torch.no_grad(): - _actions = agent.act(_states, timestep=msg['timestep'], timesteps=msg['timesteps'])[0] + _actions = agent.act(_states, timestep=msg["timestep"], timesteps=msg["timesteps"])[0] if not _actions.is_cuda: _actions.share_memory_() queue.put(_actions) @@ -74,44 +74,50 @@ def fn_processor(process_index, *args): # record agent's experience elif task == "record_transition": with torch.no_grad(): - agent.record_transition(states=_states, - actions=_actions, - rewards=queue.get()[scope[0]:scope[1]], - next_states=queue.get()[scope[0]:scope[1]], - terminated=queue.get()[scope[0]:scope[1]], - truncated=queue.get()[scope[0]:scope[1]], - infos=queue.get(), - timestep=msg['timestep'], - timesteps=msg['timesteps']) + agent.record_transition( + states=_states, + actions=_actions, + rewards=queue.get()[scope[0] : scope[1]], + next_states=queue.get()[scope[0] : scope[1]], + terminated=queue.get()[scope[0] : scope[1]], + truncated=queue.get()[scope[0] : scope[1]], + infos=queue.get(), + timestep=msg["timestep"], + timesteps=msg["timesteps"], + ) barrier.wait() # execute agent's post-interaction step elif task == "post_interaction": - agent.post_interaction(timestep=msg['timestep'], timesteps=msg['timesteps']) + agent.post_interaction(timestep=msg["timestep"], timesteps=msg["timesteps"]) barrier.wait() # write data to TensorBoard (evaluation) elif task == "eval-record_transition-post_interaction": with torch.no_grad(): - agent.record_transition(states=_states, - actions=_actions, - rewards=queue.get()[scope[0]:scope[1]], - next_states=queue.get()[scope[0]:scope[1]], - terminated=queue.get()[scope[0]:scope[1]], - truncated=queue.get()[scope[0]:scope[1]], - infos=queue.get(), - timestep=msg['timestep'], - timesteps=msg['timesteps']) - super(type(agent), agent).post_interaction(timestep=msg['timestep'], timesteps=msg['timesteps']) + agent.record_transition( + states=_states, + actions=_actions, + rewards=queue.get()[scope[0] : scope[1]], + next_states=queue.get()[scope[0] : scope[1]], + terminated=queue.get()[scope[0] : scope[1]], + truncated=queue.get()[scope[0] : scope[1]], + infos=queue.get(), + timestep=msg["timestep"], + timesteps=msg["timesteps"], + ) + super(type(agent), agent).post_interaction(timestep=msg["timestep"], timesteps=msg["timesteps"]) barrier.wait() class ParallelTrainer(Trainer): - def __init__(self, - env: Wrapper, - agents: Union[Agent, List[Agent]], - agents_scope: Optional[List[int]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + env: Wrapper, + agents: Union[Agent, List[Agent]], + agents_scope: Optional[List[int]] = None, + cfg: Optional[dict] = None, + ) -> None: """Parallel trainer Train agents in parallel using multiple processes @@ -131,7 +137,7 @@ def __init__(self, agents_scope = agents_scope if agents_scope is not None else [] super().__init__(env=env, agents=agents, agents_scope=agents_scope, cfg=_cfg) - mp.set_start_method(method='spawn', force=True) + mp.set_start_method(method="spawn", force=True) def train(self) -> None: """Train the agents in parallel @@ -189,16 +195,16 @@ def train(self) -> None: # spawn and wait for all processes to start for i in range(self.num_simultaneous_agents): - process = mp.Process(target=fn_processor, - args=(i, consumer_pipes, queues, barrier, self.agents_scope, self.cfg), - daemon=True) + process = mp.Process( + target=fn_processor, args=(i, consumer_pipes, queues, barrier, 
self.agents_scope, self.cfg), daemon=True + ) processes.append(process) process.start() barrier.wait() # initialize agents for pipe, queue, agent in zip(producer_pipes, queues, self.agents): - pipe.send({'task': 'init'}) + pipe.send({"task": "init"}) queue.put(agent) barrier.wait() @@ -207,7 +213,9 @@ def train(self) -> None: if not states.is_cuda: states.share_memory_() - for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): + for timestep in tqdm.tqdm( + range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout + ): # pre-interaction for pipe in producer_pipes: @@ -325,16 +333,16 @@ def eval(self) -> None: # spawn and wait for all processes to start for i in range(self.num_simultaneous_agents): - process = mp.Process(target=fn_processor, - args=(i, consumer_pipes, queues, barrier, self.agents_scope, self.cfg), - daemon=True) + process = mp.Process( + target=fn_processor, args=(i, consumer_pipes, queues, barrier, self.agents_scope, self.cfg), daemon=True + ) processes.append(process) process.start() barrier.wait() # initialize agents for pipe, queue, agent in zip(producer_pipes, queues, self.agents): - pipe.send({'task': 'init'}) + pipe.send({"task": "init"}) queue.put(agent) barrier.wait() @@ -343,7 +351,9 @@ def eval(self) -> None: if not states.is_cuda: states.share_memory_() - for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): + for timestep in tqdm.tqdm( + range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout + ): # pre-interaction for pipe in producer_pipes: @@ -378,9 +388,13 @@ def eval(self) -> None: # post-interaction for pipe, queue in zip(producer_pipes, queues): - pipe.send({"task": "eval-record_transition-post_interaction", - "timestep": timestep, - "timesteps": self.timesteps}) + pipe.send( + { + "task": "eval-record_transition-post_interaction", + "timestep": timestep, + "timesteps": self.timesteps, + } + ) queue.put(rewards) queue.put(next_states) queue.put(terminated) diff --git a/skrl/trainers/torch/sequential.py b/skrl/trainers/torch/sequential.py index c2111229..8304b1fd 100644 --- a/skrl/trainers/torch/sequential.py +++ b/skrl/trainers/torch/sequential.py @@ -25,11 +25,13 @@ class SequentialTrainer(Trainer): - def __init__(self, - env: Wrapper, - agents: Union[Agent, List[Agent]], - agents_scope: Optional[List[int]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + env: Wrapper, + agents: Union[Agent, List[Agent]], + agents_scope: Optional[List[int]] = None, + cfg: Optional[dict] = None, + ) -> None: """Sequential trainer Train agents sequentially (i.e., one after the other in each interaction with the environment) @@ -89,7 +91,9 @@ def train(self) -> None: # reset env states, infos = self.env.reset() - for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): + for timestep in tqdm.tqdm( + range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout + ): # pre-interaction for agent in self.agents: @@ -97,8 +101,12 @@ def train(self) -> None: with torch.no_grad(): # compute actions - actions = torch.vstack([agent.act(states[scope[0]:scope[1]], timestep=timestep, timesteps=self.timesteps)[0] \ - for agent, scope in zip(self.agents, self.agents_scope)]) + actions = torch.vstack( + [ + agent.act(states[scope[0] : scope[1]], 
timestep=timestep, timesteps=self.timesteps)[0] + for agent, scope in zip(self.agents, self.agents_scope) + ] + ) # step the environments next_states, rewards, terminated, truncated, infos = self.env.step(actions) @@ -109,15 +117,17 @@ def train(self) -> None: # record the environments' transitions for agent, scope in zip(self.agents, self.agents_scope): - agent.record_transition(states=states[scope[0]:scope[1]], - actions=actions[scope[0]:scope[1]], - rewards=rewards[scope[0]:scope[1]], - next_states=next_states[scope[0]:scope[1]], - terminated=terminated[scope[0]:scope[1]], - truncated=truncated[scope[0]:scope[1]], - infos=infos, - timestep=timestep, - timesteps=self.timesteps) + agent.record_transition( + states=states[scope[0] : scope[1]], + actions=actions[scope[0] : scope[1]], + rewards=rewards[scope[0] : scope[1]], + next_states=next_states[scope[0] : scope[1]], + terminated=terminated[scope[0] : scope[1]], + truncated=truncated[scope[0] : scope[1]], + infos=infos, + timestep=timestep, + timesteps=self.timesteps, + ) # log environment info if self.environment_info in infos: @@ -167,7 +177,9 @@ def eval(self) -> None: # reset env states, infos = self.env.reset() - for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout): + for timestep in tqdm.tqdm( + range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout + ): # pre-interaction for agent in self.agents: @@ -175,8 +187,12 @@ def eval(self) -> None: with torch.no_grad(): # compute actions - actions = torch.vstack([agent.act(states[scope[0]:scope[1]], timestep=timestep, timesteps=self.timesteps)[0] \ - for agent, scope in zip(self.agents, self.agents_scope)]) + actions = torch.vstack( + [ + agent.act(states[scope[0] : scope[1]], timestep=timestep, timesteps=self.timesteps)[0] + for agent, scope in zip(self.agents, self.agents_scope) + ] + ) # step the environments next_states, rewards, terminated, truncated, infos = self.env.step(actions) @@ -187,15 +203,17 @@ def eval(self) -> None: # write data to TensorBoard for agent, scope in zip(self.agents, self.agents_scope): - agent.record_transition(states=states[scope[0]:scope[1]], - actions=actions[scope[0]:scope[1]], - rewards=rewards[scope[0]:scope[1]], - next_states=next_states[scope[0]:scope[1]], - terminated=terminated[scope[0]:scope[1]], - truncated=truncated[scope[0]:scope[1]], - infos=infos, - timestep=timestep, - timesteps=self.timesteps) + agent.record_transition( + states=states[scope[0] : scope[1]], + actions=actions[scope[0] : scope[1]], + rewards=rewards[scope[0] : scope[1]], + next_states=next_states[scope[0] : scope[1]], + terminated=terminated[scope[0] : scope[1]], + truncated=truncated[scope[0] : scope[1]], + infos=infos, + timestep=timestep, + timesteps=self.timesteps, + ) # log environment info if self.environment_info in infos: diff --git a/skrl/trainers/torch/step.py b/skrl/trainers/torch/step.py index d4783c9d..77987598 100644 --- a/skrl/trainers/torch/step.py +++ b/skrl/trainers/torch/step.py @@ -25,11 +25,13 @@ class StepTrainer(Trainer): - def __init__(self, - env: Wrapper, - agents: Union[Agent, List[Agent]], - agents_scope: Optional[List[int]] = None, - cfg: Optional[dict] = None) -> None: + def __init__( + self, + env: Wrapper, + agents: Union[Agent, List[Agent]], + agents_scope: Optional[List[int]] = None, + cfg: Optional[dict] = None, + ) -> None: """Step-by-step trainer Train agents by controlling the training/evaluation loop step by step @@ -61,8 +63,9 
@@ def __init__(self, self.states = None - def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> \ - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: + def train( + self, timestep: Optional[int] = None, timesteps: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: """Execute a training iteration This method executes the following steps once: @@ -112,8 +115,12 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) with torch.no_grad(): # compute actions - actions = torch.vstack([agent.act(self.states[scope[0]:scope[1]], timestep=timestep, timesteps=timesteps)[0] \ - for agent, scope in zip(self.agents, self.agents_scope)]) + actions = torch.vstack( + [ + agent.act(self.states[scope[0] : scope[1]], timestep=timestep, timesteps=timesteps)[0] + for agent, scope in zip(self.agents, self.agents_scope) + ] + ) # step the environments next_states, rewards, terminated, truncated, infos = self.env.step(actions) @@ -124,15 +131,17 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) # record the environments' transitions for agent, scope in zip(self.agents, self.agents_scope): - agent.record_transition(states=self.states[scope[0]:scope[1]], - actions=actions[scope[0]:scope[1]], - rewards=rewards[scope[0]:scope[1]], - next_states=next_states[scope[0]:scope[1]], - terminated=terminated[scope[0]:scope[1]], - truncated=truncated[scope[0]:scope[1]], - infos=infos, - timestep=timestep, - timesteps=timesteps) + agent.record_transition( + states=self.states[scope[0] : scope[1]], + actions=actions[scope[0] : scope[1]], + rewards=rewards[scope[0] : scope[1]], + next_states=next_states[scope[0] : scope[1]], + terminated=terminated[scope[0] : scope[1]], + truncated=truncated[scope[0] : scope[1]], + infos=infos, + timestep=timestep, + timesteps=timesteps, + ) # log environment info if self.environment_info in infos: @@ -154,8 +163,9 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) return next_states, rewards, terminated, truncated, infos - def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> \ - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: + def eval( + self, timestep: Optional[int] = None, timesteps: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: """Evaluate the agents sequentially This method executes the following steps in loop: @@ -202,8 +212,12 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) with torch.no_grad(): # compute actions - actions = torch.vstack([agent.act(self.states[scope[0]:scope[1]], timestep=timestep, timesteps=timesteps)[0] \ - for agent, scope in zip(self.agents, self.agents_scope)]) + actions = torch.vstack( + [ + agent.act(self.states[scope[0] : scope[1]], timestep=timestep, timesteps=timesteps)[0] + for agent, scope in zip(self.agents, self.agents_scope) + ] + ) # step the environments next_states, rewards, terminated, truncated, infos = self.env.step(actions) @@ -214,15 +228,17 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) # write data to TensorBoard for agent, scope in zip(self.agents, self.agents_scope): - agent.record_transition(states=self.states[scope[0]:scope[1]], - actions=actions[scope[0]:scope[1]], - rewards=rewards[scope[0]:scope[1]], - next_states=next_states[scope[0]:scope[1]], - 
terminated=terminated[scope[0]:scope[1]], - truncated=truncated[scope[0]:scope[1]], - infos=infos, - timestep=timestep, - timesteps=timesteps) + agent.record_transition( + states=self.states[scope[0] : scope[1]], + actions=actions[scope[0] : scope[1]], + rewards=rewards[scope[0] : scope[1]], + next_states=next_states[scope[0] : scope[1]], + terminated=terminated[scope[0] : scope[1]], + truncated=truncated[scope[0] : scope[1]], + infos=infos, + timestep=timestep, + timesteps=timesteps, + ) # log environment info if self.environment_info in infos: diff --git a/skrl/utils/__init__.py b/skrl/utils/__init__.py index a03e7847..f5c0e0be 100644 --- a/skrl/utils/__init__.py +++ b/skrl/utils/__init__.py @@ -70,7 +70,7 @@ def set_seed(seed: Optional[int] = None, deterministic: bool = False) -> int: seed = int.from_bytes(os.urandom(4), byteorder=sys.byteorder) except NotImplementedError: seed = int(time.time() * 1000) - seed %= 2 ** 31 # NumPy's legacy seeding seed must be between 0 and 2**32 - 1 + seed %= 2**31 # NumPy's legacy seeding seed must be between 0 and 2**32 - 1 seed = int(seed) # set different seeds in distributed runs diff --git a/skrl/utils/control.py b/skrl/utils/control.py index 0abbfb29..f1ba0b37 100644 --- a/skrl/utils/control.py +++ b/skrl/utils/control.py @@ -3,10 +3,9 @@ import torch -def ik(jacobian_end_effector, - current_position, current_orientation, - goal_position, goal_orientation, - damping_factor=0.05): +def ik( + jacobian_end_effector, current_position, current_orientation, goal_position, goal_orientation, damping_factor=0.05 +): """ Damped Least Squares method: https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf """ @@ -20,19 +19,28 @@ def ik(jacobian_end_effector, # solve damped least squares (dO = J.T * V) transpose = torch.transpose(jacobian_end_effector, 1, 2) - lmbda = torch.eye(6).to(jacobian_end_effector.device) * (damping_factor ** 2) - return (transpose @ torch.inverse(jacobian_end_effector @ transpose + lmbda) @ dpose) - -def osc(jacobian_end_effector, mass_matrix, - current_position, current_orientation, - goal_position, goal_orientation, - current_dof_velocities, - kp=5, kv=2): + lmbda = torch.eye(6).to(jacobian_end_effector.device) * (damping_factor**2) + return transpose @ torch.inverse(jacobian_end_effector @ transpose + lmbda) @ dpose + + +def osc( + jacobian_end_effector, + mass_matrix, + current_position, + current_orientation, + goal_position, + goal_orientation, + current_dof_velocities, + kp=5, + kv=2, +): """ https://studywolf.wordpress.com/2013/09/17/robot-control-4-operation-space-control/ """ - mass_matrix_end_effector = torch.inverse(jacobian_end_effector @ torch.inverse(mass_matrix) @ torch.transpose(jacobian_end_effector, 1, 2)) + mass_matrix_end_effector = torch.inverse( + jacobian_end_effector @ torch.inverse(mass_matrix) @ torch.transpose(jacobian_end_effector, 1, 2) + ) # compute position and orientation error position_error = kp * (goal_position - current_position) @@ -41,4 +49,7 @@ def osc(jacobian_end_effector, mass_matrix, dpose = torch.cat([position_error, orientation_error], -1) - return torch.transpose(jacobian_end_effector, 1, 2) @ mass_matrix_end_effector @ (kp * dpose).unsqueeze(-1) - kv * mass_matrix @ current_dof_velocities + return ( + torch.transpose(jacobian_end_effector, 1, 2) @ mass_matrix_end_effector @ (kp * dpose).unsqueeze(-1) + - kv * mass_matrix @ current_dof_velocities + ) diff --git a/skrl/utils/distributed/jax/launcher.py b/skrl/utils/distributed/jax/launcher.py index 5f4ae403..8be6cab4 100644 --- 
a/skrl/utils/distributed/jax/launcher.py +++ b/skrl/utils/distributed/jax/launcher.py @@ -18,11 +18,18 @@ def _get_args_parser() -> argparse.ArgumentParser: # worker/node size related arguments parser.add_argument("--nnodes", type=int, default=1, help="Number of nodes") parser.add_argument("--nproc-per-node", "--nproc_per_node", type=int, default=1, help="Number of workers per node") - parser.add_argument("--node-rank", "--node_rank", type=int, default=0, help="Node rank for multi-node distributed training") + parser.add_argument( + "--node-rank", "--node_rank", type=int, default=0, help="Node rank for multi-node distributed training" + ) # coordinator related arguments - parser.add_argument("--coordinator-address", "--coordinator_address", type=str, default="127.0.0.1:5000", - help="IP address and port where process 0 will start a JAX service") + parser.add_argument( + "--coordinator-address", + "--coordinator_address", + type=str, + default="127.0.0.1:5000", + help="IP address and port where process 0 will start a JAX service", + ) # positional arguments parser.add_argument("script", type=str, help="Training script path to be launched in parallel") @@ -30,7 +37,14 @@ def _get_args_parser() -> argparse.ArgumentParser: return parser -def _start_processes(cmd: Sequence[str], envs: Sequence[Mapping[str, str]], nprocs: int, daemon: bool = False, start_method: str = "spawn") -> None: + +def _start_processes( + cmd: Sequence[str], + envs: Sequence[Mapping[str, str]], + nprocs: int, + daemon: bool = False, + start_method: str = "spawn", +) -> None: """Start child processes according the specified configuration and wait for them to join :param cmd: Command to run on each child process @@ -57,6 +71,7 @@ def _start_processes(cmd: Sequence[str], envs: Sequence[Mapping[str, str]], npro for process in processes: process.join() + def _process(cmd: Sequence[str], env: Mapping[str, str]) -> None: """Run a command in the current process @@ -67,6 +82,7 @@ def _process(cmd: Sequence[str], env: Mapping[str, str]) -> None: """ subprocess.run(cmd, env=env) + def launch(): """Main entry point for launching distributed runs""" args = _get_args_parser().parse_args() diff --git a/skrl/utils/huggingface.py b/skrl/utils/huggingface.py index 6721ffb7..a12c3474 100644 --- a/skrl/utils/huggingface.py +++ b/skrl/utils/huggingface.py @@ -38,9 +38,8 @@ def download_model_from_huggingface(repo_id: str, filename: str = "agent.pt") -> raise ImportError("Hugging Face Hub package is not installed. 
Use 'pip install huggingface-hub' to install it") # download and cache the model from Hugging Face Hub - downloaded_model_file = huggingface_hub.hf_hub_download(repo_id=repo_id, - filename=filename, - library_name="skrl", - library_version=__version__) + downloaded_model_file = huggingface_hub.hf_hub_download( + repo_id=repo_id, filename=filename, library_name="skrl", library_version=__version__ + ) return downloaded_model_file diff --git a/skrl/utils/isaacgym_utils.py b/skrl/utils/isaacgym_utils.py index 325c4af2..6676e290 100644 --- a/skrl/utils/isaacgym_utils.py +++ b/skrl/utils/isaacgym_utils.py @@ -40,7 +40,7 @@ def __init__(self, host: str = "127.0.0.1", port: int = 5000) -> None: self._app.add_url_rule("/_route_stream", view_func=self._route_stream) self._app.add_url_rule("/_route_input_event", view_func=self._route_input_event, methods=["POST"]) - self._log = logging.getLogger('werkzeug') + self._log = logging.getLogger("werkzeug") self._log.disabled = True self._app.logger.disabled = True @@ -54,12 +54,13 @@ def __init__(self, host: str = "127.0.0.1", port: int = 5000) -> None: self._event_stream = threading.Event() # start server - self._thread = threading.Thread(target=lambda: \ - self._app.run(host=host, port=port, debug=False, use_reloader=False), daemon=True) + self._thread = threading.Thread( + target=lambda: self._app.run(host=host, port=port, debug=False, use_reloader=False), daemon=True + ) self._thread.start() print(f"\nStarting web viewer on http://{host}:{port}/\n") - def _route_index(self) -> 'flask.Response': + def _route_index(self) -> "flask.Response": """Render the web page :return: Flask response @@ -145,25 +146,28 @@ def _route_index(self) -> 'flask.Response': self._event_load.set() return flask.render_template_string(template) - def _route_stream(self) -> 'flask.Response': + def _route_stream(self) -> "flask.Response": """Stream the image to the web page :return: Flask response :rtype: flask.Response """ - return flask.Response(self._stream(), mimetype='multipart/x-mixed-replace; boundary=frame') + return flask.Response(self._stream(), mimetype="multipart/x-mixed-replace; boundary=frame") - def _route_input_event(self) -> 'flask.Response': + def _route_input_event(self) -> "flask.Response": """Handle keyboard and mouse input :return: Flask response :rtype: flask.Response """ + def q_mult(q1, q2): - return [q1[0] * q2[0] - q1[1] * q2[1] - q1[2] * q2[2] - q1[3] * q2[3], - q1[0] * q2[1] + q1[1] * q2[0] + q1[2] * q2[3] - q1[3] * q2[2], - q1[0] * q2[2] + q1[2] * q2[0] + q1[3] * q2[1] - q1[1] * q2[3], - q1[0] * q2[3] + q1[3] * q2[0] + q1[1] * q2[2] - q1[2] * q2[1]] + return [ + q1[0] * q2[0] - q1[1] * q2[1] - q1[2] * q2[2] - q1[3] * q2[3], + q1[0] * q2[1] + q1[1] * q2[0] + q1[2] * q2[3] - q1[3] * q2[2], + q1[0] * q2[2] + q1[2] * q2[0] + q1[3] * q2[1] - q1[1] * q2[3], + q1[0] * q2[3] + q1[3] * q2[0] + q1[1] * q2[2] - q1[2] * q2[1], + ] def q_conj(q): return [q[0], -q[1], -q[2], -q[3]] @@ -190,15 +194,14 @@ def p_target(p, q, a=0, b=0, c=1, d=0): key, mouse = data.get("key", None), data.get("mouse", None) dx, dy, dz = data.get("dx", None), data.get("dy", None), data.get("dz", None) - transform = self._gym.get_camera_transform(self._sim, - self._envs[self._camera_id], - self._cameras[self._camera_id]) + transform = self._gym.get_camera_transform( + self._sim, self._envs[self._camera_id], self._cameras[self._camera_id] + ) # zoom in/out if mouse == "wheel": # compute zoom vector - vector = qv_mult([transform.r.w, transform.r.x, transform.r.y, transform.r.z], - [-0.025 * 
dz, 0, 0]) + vector = qv_mult([transform.r.w, transform.r.x, transform.r.y, transform.r.z], [-0.025 * dz, 0, 0]) # update transform transform.p.x += vector[0] @@ -216,8 +219,10 @@ def p_target(p, q, a=0, b=0, c=1, d=0): q = q_mult(q, q_from_angle_axis(dy, [1, 0, 0])) # apply rotation - t = p_target([transform.p.x, transform.p.y, transform.p.z], - [transform.r.w, transform.r.x, transform.r.y, transform.r.z]) + t = p_target( + [transform.p.x, transform.p.y, transform.p.z], + [transform.r.w, transform.r.x, transform.r.y, transform.r.z], + ) p = qv_mult(q, [transform.p.x - t[0], transform.p.y - t[1], transform.p.z - t[2]]) q = q_mult(q, [transform.r.w, transform.r.x, transform.r.y, transform.r.z]) @@ -246,8 +251,7 @@ def p_target(p, q, a=0, b=0, c=1, d=0): # walk camera elif mouse == "middle": # compute displacement - vector = qv_mult([transform.r.w, transform.r.x, transform.r.y, transform.r.z], - [0, 0.001 * dx, 0.001 * dy]) + vector = qv_mult([transform.r.w, transform.r.x, transform.r.y, transform.r.z], [0, 0.001 * dx, 0.001 * dy]) # update transform transform.p.x += vector[0] @@ -270,9 +274,7 @@ def p_target(p, q, a=0, b=0, c=1, d=0): else: return flask.Response(status=200) - self._gym.set_camera_transform(self._cameras[self._camera_id], - self._envs[self._camera_id], - transform) + self._gym.set_camera_transform(self._cameras[self._camera_id], self._envs[self._camera_id], transform) return flask.Response(status=200) @@ -289,13 +291,14 @@ def _stream(self) -> bytes: image = imageio.imwrite("", self._image, format="JPEG") # stream image - yield (b'--frame\r\n' - b'Content-Type: image/jpeg\r\n\r\n' + image + b'\r\n') + yield (b"--frame\r\n" b"Content-Type: image/jpeg\r\n\r\n" + image + b"\r\n") self._event_stream.clear() self._notified = False - def setup(self, gym: 'isaacgym.gymapi.Gym', sim: 'isaacgym.gymapi.Sim', envs: List[int], cameras: List[int]) -> None: + def setup( + self, gym: "isaacgym.gymapi.Gym", sim: "isaacgym.gymapi.Sim", envs: List[int], cameras: List[int] + ) -> None: """Setup the web viewer :param gym: The gym @@ -312,11 +315,13 @@ def setup(self, gym: 'isaacgym.gymapi.Gym', sim: 'isaacgym.gymapi.Sim', envs: Li self._envs = envs self._cameras = cameras - def render(self, - fetch_results: bool = True, - step_graphics: bool = True, - render_all_camera_sensors: bool = True, - wait_for_page_load: bool = True) -> None: + def render( + self, + fetch_results: bool = True, + step_graphics: bool = True, + render_all_camera_sensors: bool = True, + wait_for_page_load: bool = True, + ) -> None: """Render and get the image from the current camera This function must be called after the simulation is stepped (post_physics_step). 
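For reference, the WebViewer API reformatted above is normally driven from the simulation loop: construct the viewer, register the simulation handles with setup(), and call render() after each physics step. The snippet below is a minimal illustrative sketch, not part of this patch; gym, sim, envs and cameras are placeholders for handles created during the user's Isaac Gym setup.

    from skrl.utils.isaacgym_utils import WebViewer

    web_viewer = WebViewer(host="127.0.0.1", port=5000)  # starts the Flask server in a background thread
    web_viewer.setup(gym, sim, envs, cameras)            # handles created during the Isaac Gym setup

    while True:
        gym.simulate(sim)                                # step the physics
        web_viewer.render(fetch_results=True,
                          step_graphics=True,
                          render_all_camera_sensors=True,
                          wait_for_page_load=True)       # grab the camera image and stream it to the browser
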
@@ -361,10 +366,9 @@ def render(self, self._gym.render_all_camera_sensors(self._sim) # get image - image = self._gym.get_camera_image(self._sim, - self._envs[self._camera_id], - self._cameras[self._camera_id], - self._camera_type) + image = self._gym.get_camera_image( + self._sim, self._envs[self._camera_id], self._cameras[self._camera_id], self._camera_type + ) if self._camera_type == gymapi.IMAGE_COLOR: self._image = image.reshape(image.shape[0], -1, 4)[..., :3] elif self._camera_type == gymapi.IMAGE_DEPTH: @@ -381,13 +385,15 @@ def render(self, self._notified = True -def ik(jacobian_end_effector: torch.Tensor, - current_position: torch.Tensor, - current_orientation: torch.Tensor, - goal_position: torch.Tensor, - goal_orientation: Optional[torch.Tensor] = None, - damping_factor: float = 0.05, - squeeze_output: bool = True) -> torch.Tensor: +def ik( + jacobian_end_effector: torch.Tensor, + current_position: torch.Tensor, + current_orientation: torch.Tensor, + goal_position: torch.Tensor, + goal_orientation: Optional[torch.Tensor] = None, + damping_factor: float = 0.05, + squeeze_output: bool = True, +) -> torch.Tensor: """ Inverse kinematics using damped least squares method @@ -414,42 +420,89 @@ def ik(jacobian_end_effector: torch.Tensor, # compute error q = torch_utils.quat_mul(goal_orientation, torch_utils.quat_conjugate(current_orientation)) - error = torch.cat([goal_position - current_position, # position error - q[:, 0:3] * torch.sign(q[:, 3]).unsqueeze(-1)], # orientation error - dim=-1).unsqueeze(-1) + error = torch.cat( + [ + goal_position - current_position, # position error + q[:, 0:3] * torch.sign(q[:, 3]).unsqueeze(-1), + ], # orientation error + dim=-1, + ).unsqueeze(-1) # solve damped least squares (dO = J.T * V) transpose = torch.transpose(jacobian_end_effector, 1, 2) - lmbda = torch.eye(6, device=jacobian_end_effector.device) * (damping_factor ** 2) + lmbda = torch.eye(6, device=jacobian_end_effector.device) * (damping_factor**2) if squeeze_output: return (transpose @ torch.inverse(jacobian_end_effector @ transpose + lmbda) @ error).squeeze(dim=2) else: return transpose @ torch.inverse(jacobian_end_effector @ transpose + lmbda) @ error + def print_arguments(args): print("") print("Arguments") for a in args.__dict__: print(f" |-- {a}: {args.__getattribute__(a)}") -def print_asset_options(asset_options: 'isaacgym.gymapi.AssetOptions', asset_name: str = ""): - attrs = ["angular_damping", "armature", "collapse_fixed_joints", "convex_decomposition_from_submeshes", - "default_dof_drive_mode", "density", "disable_gravity", "fix_base_link", "flip_visual_attachments", - "linear_damping", "max_angular_velocity", "max_linear_velocity", "mesh_normal_mode", "min_particle_mass", - "override_com", "override_inertia", "replace_cylinder_with_capsule", "tendon_limit_stiffness", "thickness", - "use_mesh_materials", "use_physx_armature", "vhacd_enabled"] # vhacd_params + +def print_asset_options(asset_options: "isaacgym.gymapi.AssetOptions", asset_name: str = ""): + attrs = [ + "angular_damping", + "armature", + "collapse_fixed_joints", + "convex_decomposition_from_submeshes", + "default_dof_drive_mode", + "density", + "disable_gravity", + "fix_base_link", + "flip_visual_attachments", + "linear_damping", + "max_angular_velocity", + "max_linear_velocity", + "mesh_normal_mode", + "min_particle_mass", + "override_com", + "override_inertia", + "replace_cylinder_with_capsule", + "tendon_limit_stiffness", + "thickness", + "use_mesh_materials", + "use_physx_armature", + "vhacd_enabled", + ] # 
vhacd_params print("\nAsset options{}".format(f" ({asset_name})" if asset_name else "")) for attr in attrs: print(" |-- {}: {}".format(attr, getattr(asset_options, attr) if hasattr(asset_options, attr) else "--")) # vhacd attributes if attr == "vhacd_enabled" and hasattr(asset_options, attr) and getattr(asset_options, attr): - vhacd_attrs = ["alpha", "beta", "concavity", "convex_hull_approximation", "convex_hull_downsampling", - "max_convex_hulls", "max_num_vertices_per_ch", "min_volume_per_ch", "mode", "ocl_acceleration", - "pca", "plane_downsampling", "project_hull_vertices", "resolution"] + vhacd_attrs = [ + "alpha", + "beta", + "concavity", + "convex_hull_approximation", + "convex_hull_downsampling", + "max_convex_hulls", + "max_num_vertices_per_ch", + "min_volume_per_ch", + "mode", + "ocl_acceleration", + "pca", + "plane_downsampling", + "project_hull_vertices", + "resolution", + ] print(" |-- vhacd_params:") for vhacd_attr in vhacd_attrs: - print(" | |-- {}: {}".format(vhacd_attr, getattr(asset_options.vhacd_params, vhacd_attr) \ - if hasattr(asset_options.vhacd_params, vhacd_attr) else "--")) + print( + " | |-- {}: {}".format( + vhacd_attr, + ( + getattr(asset_options.vhacd_params, vhacd_attr) + if hasattr(asset_options.vhacd_params, vhacd_attr) + else "--" + ), + ) + ) + def print_sim_components(gym, sim): print("") @@ -461,6 +514,7 @@ def print_sim_components(gym, sim): print(" |-- dof count:", gym.get_sim_dof_count(sim)) print(" |-- force sensor count:", gym.get_sim_force_sensor_count(sim)) + def print_env_components(gym, env): print("") print("Env components") @@ -469,6 +523,7 @@ def print_env_components(gym, env): print(" |-- joint count:", gym.get_env_joint_count(env)) print(" |-- dof count:", gym.get_env_dof_count(env)) + def print_actor_components(gym, env, actor): print("") print("Actor components") @@ -480,6 +535,7 @@ def print_actor_components(gym, env, actor): print(" |-- soft body count:", gym.get_actor_soft_body_count(env, actor)) print(" |-- tendon count:", gym.get_actor_tendon_count(env, actor)) + def print_dof_properties(gymapi, props): print("") print("DOF properties") @@ -498,6 +554,7 @@ def print_dof_properties(gymapi, props): print(" |-- friction:", props["friction"]) print(" |-- armature:", props["armature"]) + def print_links_and_dofs(gym, asset): link_dict = gym.get_asset_rigid_body_dict(asset) dof_dict = gym.get_asset_dof_dict(asset) diff --git a/skrl/utils/model_instantiators/jax/__init__.py b/skrl/utils/model_instantiators/jax/__init__.py index 24b3ef3b..8ce5614e 100644 --- a/skrl/utils/model_instantiators/jax/__init__.py +++ b/skrl/utils/model_instantiators/jax/__init__.py @@ -10,6 +10,7 @@ class Shape(Enum): """ Enum to select the shape of the model's inputs and outputs """ + ONE = 1 STATES = 0 OBSERVATIONS = 0 diff --git a/skrl/utils/model_instantiators/jax/categorical.py b/skrl/utils/model_instantiators/jax/categorical.py index 844b4aa2..21e33ffb 100644 --- a/skrl/utils/model_instantiators/jax/categorical.py +++ b/skrl/utils/model_instantiators/jax/categorical.py @@ -13,15 +13,17 @@ from skrl.utils.spaces.jax import unflatten_tensorized_space # noqa -def categorical_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, jax.Device]] = None, - unnormalized_log_prob: bool = True, - network: Sequence[Mapping[str, Any]] = [], - output: Union[str, Sequence[str]] = "", - return_source: bool = False, - *args, - **kwargs) -> 
Union[Model, str]: +def categorical_model( + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, jax.Device]] = None, + unnormalized_log_prob: bool = True, + network: Sequence[Mapping[str, Any]] = [], + output: Union[str, Sequence[str]] = "", + return_source: bool = False, + *args, + **kwargs, +) -> Union[Model, str]: """Instantiate a categorical model :param observation_space: Observation/state space or shape (default: None). @@ -96,7 +98,9 @@ def __call__(self, inputs, role): # instantiate model _locals = {} exec(template, globals(), _locals) - return _locals["CategoricalModel"](observation_space=observation_space, - action_space=action_space, - device=device, - unnormalized_log_prob=unnormalized_log_prob) + return _locals["CategoricalModel"]( + observation_space=observation_space, + action_space=action_space, + device=device, + unnormalized_log_prob=unnormalized_log_prob, + ) diff --git a/skrl/utils/model_instantiators/jax/common.py b/skrl/utils/model_instantiators/jax/common.py index 52d79e2b..6337bb3a 100644 --- a/skrl/utils/model_instantiators/jax/common.py +++ b/skrl/utils/model_instantiators/jax/common.py @@ -37,6 +37,7 @@ def _get_activation_function(activation: Union[str, None]) -> Union[str, None]: } return activations.get(activation.lower() if type(activation) is str else activation, None) + def _parse_input(source: str) -> str: """Parse a network input expression by replacing substitutions and applying operations @@ -44,6 +45,7 @@ def _parse_input(source: str) -> str: :return: Parsed network input """ + class NodeTransformer(ast.NodeTransformer): def visit_Call(self, node: ast.Call): if isinstance(node.func, ast.Name): @@ -61,13 +63,18 @@ def visit_Call(self, node: ast.Call): NodeTransformer().visit(tree) source = ast.unparse(tree) # enum substitutions - source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace("STATES_ACTIONS", "jnp.concatenate([states, taken_actions], axis=-1)") - source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace("OBSERVATIONS_ACTIONS", "jnp.concatenate([states, taken_actions], axis=-1)") + source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace( + "STATES_ACTIONS", "jnp.concatenate([states, taken_actions], axis=-1)" + ) + source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace( + "OBSERVATIONS_ACTIONS", "jnp.concatenate([states, taken_actions], axis=-1)" + ) source = source.replace("Shape.STATES", "STATES").replace("STATES", "states") source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", "states") source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", "taken_actions") return source + def _parse_output(source: Union[str, Sequence[str]]) -> Tuple[Union[str, Sequence[str]], Sequence[str], int]: """Parse the network output expression by replacing substitutions and applying operations @@ -75,6 +82,7 @@ def _parse_output(source: Union[str, Sequence[str]]) -> Tuple[Union[str, Sequenc :return: Tuple with the parsed network output, generated modules and output size/shape """ + class NodeTransformer(ast.NodeTransformer): def visit_Call(self, node: ast.Call): if isinstance(node.func, ast.Name): @@ -109,6 +117,7 @@ def visit_Call(self, node: ast.Call): raise ValueError(f"Invalid or unsupported network output definition: {source}") return source, modules, size + def _generate_modules(layers: 
Sequence[str], activations: Union[Sequence[str], str]) -> Sequence[str]: """Generate network modules @@ -152,7 +161,7 @@ def _generate_modules(layers: Sequence[str], activations: Union[Sequence[str], s if type(kwargs) in [int, float]: kwargs = {"features": int(kwargs)} elif type(kwargs) is list: - kwargs = {k: v for k, v in zip(["features", "use_bias"][:len(kwargs)], kwargs)} + kwargs = {k: v for k, v in zip(["features", "use_bias"][: len(kwargs)], kwargs)} elif type(kwargs) is dict: if "in_features" in kwargs: del kwargs["in_features"] @@ -169,7 +178,12 @@ def _generate_modules(layers: Sequence[str], activations: Union[Sequence[str], s cls = "nn.Conv" kwargs = layer[layer_type] if type(kwargs) is list: - kwargs = {k: v for k, v in zip(["features", "kernel_size", "strides", "padding", "use_bias"][:len(kwargs)], kwargs)} + kwargs = { + k: v + for k, v in zip( + ["features", "kernel_size", "strides", "padding", "use_bias"][: len(kwargs)], kwargs + ) + } elif type(kwargs) is dict: if "in_channels" in kwargs: del kwargs["in_channels"] @@ -201,6 +215,7 @@ def _generate_modules(layers: Sequence[str], activations: Union[Sequence[str], s modules.append(activation) return modules + def get_num_units(token: Union[str, Any]) -> Union[str, Any]: """Get the number of units/features a token represent @@ -221,9 +236,10 @@ def get_num_units(token: Union[str, Any]) -> Union[str, Any]: return num_units[token_as_str] return token -def generate_containers(network: Sequence[Mapping[str, Any]], output: Union[str, Sequence[str]], - embed_output: bool = True, indent: int = -1) -> \ - Tuple[Sequence[Mapping[str, Any]], Mapping[str, Any]]: + +def generate_containers( + network: Sequence[Mapping[str, Any]], output: Union[str, Sequence[str]], embed_output: bool = True, indent: int = -1 +) -> Tuple[Sequence[Mapping[str, Any]], Mapping[str, Any]]: """Generate network containers :param network: Network definition @@ -268,6 +284,7 @@ def generate_containers(network: Sequence[Mapping[str, Any]], output: Union[str, output = {"output": output, "modules": output_modules, "size": output_size} return containers, output + def convert_deprecated_parameters(parameters: Mapping[str, Any]) -> Tuple[Mapping[str, Any], str]: """Function to convert deprecated parameters to network-output format @@ -275,8 +292,10 @@ def convert_deprecated_parameters(parameters: Mapping[str, Any]) -> Tuple[Mappin :return: Network and output definitions """ - logger.warning(f'The following parameters ({", ".join(list(parameters.keys()))}) are deprecated. ' - "See https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html") + logger.warning( + f'The following parameters ({", ".join(list(parameters.keys()))}) are deprecated. 
' + "See https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html" + ) # network definition activations = parameters.get("hidden_activation", []) if type(activations) in [list, tuple] and len(set(activations)) == 1: diff --git a/skrl/utils/model_instantiators/jax/deterministic.py b/skrl/utils/model_instantiators/jax/deterministic.py index be7f9bd9..f769eb4b 100644 --- a/skrl/utils/model_instantiators/jax/deterministic.py +++ b/skrl/utils/model_instantiators/jax/deterministic.py @@ -13,15 +13,17 @@ from skrl.utils.spaces.jax import unflatten_tensorized_space # noqa -def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, jax.Device]] = None, - clip_actions: bool = False, - network: Sequence[Mapping[str, Any]] = [], - output: Union[str, Sequence[str]] = "", - return_source: bool = False, - *args, - **kwargs) -> Union[Model, str]: +def deterministic_model( + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, jax.Device]] = None, + clip_actions: bool = False, + network: Sequence[Mapping[str, Any]] = [], + output: Union[str, Sequence[str]] = "", + return_source: bool = False, + *args, + **kwargs, +) -> Union[Model, str]: """Instantiate a deterministic model :param observation_space: Observation/state space or shape (default: None). @@ -93,7 +95,6 @@ def __call__(self, inputs, role): # instantiate model _locals = {} exec(template, globals(), _locals) - return _locals["DeterministicModel"](observation_space=observation_space, - action_space=action_space, - device=device, - clip_actions=clip_actions) + return _locals["DeterministicModel"]( + observation_space=observation_space, action_space=action_space, device=device, clip_actions=clip_actions + ) diff --git a/skrl/utils/model_instantiators/jax/gaussian.py b/skrl/utils/model_instantiators/jax/gaussian.py index 865e7ff3..529be7fd 100644 --- a/skrl/utils/model_instantiators/jax/gaussian.py +++ b/skrl/utils/model_instantiators/jax/gaussian.py @@ -13,19 +13,21 @@ from skrl.utils.spaces.jax import unflatten_tensorized_space # noqa -def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, jax.Device]] = None, - clip_actions: bool = False, - clip_log_std: bool = True, - min_log_std: float = -20, - max_log_std: float = 2, - initial_log_std: float = 0, - network: Sequence[Mapping[str, Any]] = [], - output: Union[str, Sequence[str]] = "", - return_source: bool = False, - *args, - **kwargs) -> Union[Model, str]: +def gaussian_model( + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, jax.Device]] = None, + clip_actions: bool = False, + clip_log_std: bool = True, + min_log_std: float = -20, + max_log_std: float = 2, + initial_log_std: float = 0, + network: Sequence[Mapping[str, Any]] = [], + output: Union[str, Sequence[str]] = "", + return_source: bool = False, + *args, + **kwargs, +) -> Union[Model, str]: """Instantiate a Gaussian model :param observation_space: Observation/state space or shape (default: None). 
@@ -107,10 +109,12 @@ def __call__(self, inputs, role): # instantiate model _locals = {} exec(template, globals(), _locals) - return _locals["GaussianModel"](observation_space=observation_space, - action_space=action_space, - device=device, - clip_actions=clip_actions, - clip_log_std=clip_log_std, - min_log_std=min_log_std, - max_log_std=max_log_std) + return _locals["GaussianModel"]( + observation_space=observation_space, + action_space=action_space, + device=device, + clip_actions=clip_actions, + clip_log_std=clip_log_std, + min_log_std=min_log_std, + max_log_std=max_log_std, + ) diff --git a/skrl/utils/model_instantiators/torch/__init__.py b/skrl/utils/model_instantiators/torch/__init__.py index 1bc18267..83528042 100644 --- a/skrl/utils/model_instantiators/torch/__init__.py +++ b/skrl/utils/model_instantiators/torch/__init__.py @@ -12,6 +12,7 @@ class Shape(Enum): """ Enum to select the shape of the model's inputs and outputs """ + ONE = 1 STATES = 0 OBSERVATIONS = 0 diff --git a/skrl/utils/model_instantiators/torch/categorical.py b/skrl/utils/model_instantiators/torch/categorical.py index bb930211..cf4b7510 100644 --- a/skrl/utils/model_instantiators/torch/categorical.py +++ b/skrl/utils/model_instantiators/torch/categorical.py @@ -12,15 +12,17 @@ from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa -def categorical_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - unnormalized_log_prob: bool = True, - network: Sequence[Mapping[str, Any]] = [], - output: Union[str, Sequence[str]] = "", - return_source: bool = False, - *args, - **kwargs) -> Union[Model, str]: +def categorical_model( + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + unnormalized_log_prob: bool = True, + network: Sequence[Mapping[str, Any]] = [], + output: Union[str, Sequence[str]] = "", + return_source: bool = False, + *args, + **kwargs, +) -> Union[Model, str]: """Instantiate a categorical model :param observation_space: Observation/state space or shape (default: None). 
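A comparable sketch for the categorical (discrete-action) variant, with illustrative spaces and layer sizes:

    import gymnasium
    from skrl.utils.model_instantiators.torch import categorical_model

    policy = categorical_model(
        observation_space=gymnasium.spaces.Box(low=-1, high=1, shape=(4,)),
        action_space=gymnasium.spaces.Discrete(2),
        device="cpu",
        unnormalized_log_prob=True,                # network outputs are treated as unnormalized logits
        network=[{"name": "net",
                  "input": "STATES",
                  "layers": [32, 32],
                  "activations": "relu"}],
        output="ACTIONS",
    )
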
@@ -94,7 +96,9 @@ def compute(self, inputs, role=""): # instantiate model _locals = {} exec(template, globals(), _locals) - return _locals["CategoricalModel"](observation_space=observation_space, - action_space=action_space, - device=device, - unnormalized_log_prob=unnormalized_log_prob) + return _locals["CategoricalModel"]( + observation_space=observation_space, + action_space=action_space, + device=device, + unnormalized_log_prob=unnormalized_log_prob, + ) diff --git a/skrl/utils/model_instantiators/torch/common.py b/skrl/utils/model_instantiators/torch/common.py index cde885d3..f366811a 100644 --- a/skrl/utils/model_instantiators/torch/common.py +++ b/skrl/utils/model_instantiators/torch/common.py @@ -38,6 +38,7 @@ def _get_activation_function(activation: Union[str, None], as_module: bool = Tru } return activations.get(activation.lower() if type(activation) is str else activation, None) + def _parse_input(source: str) -> str: """Parse a network input expression by replacing substitutions and applying operations @@ -45,6 +46,7 @@ def _parse_input(source: str) -> str: :return: Parsed network input """ + class NodeTransformer(ast.NodeTransformer): def visit_Call(self, node: ast.Call): if isinstance(node.func, ast.Name): @@ -62,13 +64,18 @@ def visit_Call(self, node: ast.Call): NodeTransformer().visit(tree) source = ast.unparse(tree) # enum substitutions - source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace("STATES_ACTIONS", "torch.cat([states, taken_actions], dim=1)") - source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace("OBSERVATIONS_ACTIONS", "torch.cat([states, taken_actions], dim=1)") + source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace( + "STATES_ACTIONS", "torch.cat([states, taken_actions], dim=1)" + ) + source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace( + "OBSERVATIONS_ACTIONS", "torch.cat([states, taken_actions], dim=1)" + ) source = source.replace("Shape.STATES", "STATES").replace("STATES", "states") source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", "states") source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", "taken_actions") return source + def _parse_output(source: Union[str, Sequence[str]]) -> Tuple[Union[str, Sequence[str]], Sequence[str], int]: """Parse the network output expression by replacing substitutions and applying operations @@ -76,6 +83,7 @@ def _parse_output(source: Union[str, Sequence[str]]) -> Tuple[Union[str, Sequenc :return: Tuple with the parsed network output, generated modules and output size/shape """ + class NodeTransformer(ast.NodeTransformer): def visit_Call(self, node: ast.Call): if isinstance(node.func, ast.Name): @@ -110,6 +118,7 @@ def visit_Call(self, node: ast.Call): raise ValueError(f"Invalid or unsupported network output definition: {source}") return source, modules, size + def _generate_modules(layers: Sequence[str], activations: Union[Sequence[str], str]) -> Sequence[str]: """Generate network modules @@ -153,7 +162,7 @@ def _generate_modules(layers: Sequence[str], activations: Union[Sequence[str], s if type(kwargs) in [int, float]: kwargs = {"out_features": int(kwargs)} elif type(kwargs) is list: - kwargs = {k: v for k, v in zip(["out_features", "bias"][:len(kwargs)], kwargs)} + kwargs = {k: v for k, v in zip(["out_features", "bias"][: len(kwargs)], kwargs)} elif type(kwargs) is dict: mapping = { "features": "out_features", @@ -172,7 +181,12 @@ def _generate_modules(layers: 
Sequence[str], activations: Union[Sequence[str], s cls = "nn.LazyConv2d" kwargs = layer[layer_type] if type(kwargs) is list: - kwargs = {k: v for k, v in zip(["out_channels", "kernel_size", "stride", "padding", "bias"][:len(kwargs)], kwargs)} + kwargs = { + k: v + for k, v in zip( + ["out_channels", "kernel_size", "stride", "padding", "bias"][: len(kwargs)], kwargs + ) + } elif type(kwargs) is dict: mapping = { "features": "out_channels", @@ -191,7 +205,7 @@ def _generate_modules(layers: Sequence[str], activations: Union[Sequence[str], s activation = "" # don't add activation after flatten layer kwargs = layer[layer_type] if type(kwargs) is list: - kwargs = {k: v for k, v in zip(["start_dim", "end_dim"][:len(kwargs)], kwargs)} + kwargs = {k: v for k, v in zip(["start_dim", "end_dim"][: len(kwargs)], kwargs)} elif type(kwargs) is dict: pass else: @@ -208,6 +222,7 @@ def _generate_modules(layers: Sequence[str], activations: Union[Sequence[str], s modules.append(activation) return modules + def get_num_units(token: Union[str, Any]) -> Union[str, Any]: """Get the number of units/features a token represent @@ -228,9 +243,10 @@ def get_num_units(token: Union[str, Any]) -> Union[str, Any]: return num_units[token_as_str] return token -def generate_containers(network: Sequence[Mapping[str, Any]], output: Union[str, Sequence[str]], - embed_output: bool = True, indent: int = -1) -> \ - Tuple[Sequence[Mapping[str, Any]], Mapping[str, Any]]: + +def generate_containers( + network: Sequence[Mapping[str, Any]], output: Union[str, Sequence[str]], embed_output: bool = True, indent: int = -1 +) -> Tuple[Sequence[Mapping[str, Any]], Mapping[str, Any]]: """Generate network containers :param network: Network definition @@ -275,6 +291,7 @@ def generate_containers(network: Sequence[Mapping[str, Any]], output: Union[str, output = {"output": output, "modules": output_modules, "size": output_size} return containers, output + def convert_deprecated_parameters(parameters: Mapping[str, Any]) -> Tuple[Mapping[str, Any], str]: """Function to convert deprecated parameters to network-output format @@ -282,8 +299,10 @@ def convert_deprecated_parameters(parameters: Mapping[str, Any]) -> Tuple[Mappin :return: Network and output definitions """ - logger.warning(f'The following parameters ({", ".join(list(parameters.keys()))}) are deprecated. ' - "See https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html") + logger.warning( + f'The following parameters ({", ".join(list(parameters.keys()))}) are deprecated. 
' + "See https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html" + ) # network definition activations = parameters.get("hidden_activation", []) if type(activations) in [list, tuple] and len(set(activations)) == 1: diff --git a/skrl/utils/model_instantiators/torch/deterministic.py b/skrl/utils/model_instantiators/torch/deterministic.py index 440223cd..fad41af5 100644 --- a/skrl/utils/model_instantiators/torch/deterministic.py +++ b/skrl/utils/model_instantiators/torch/deterministic.py @@ -12,15 +12,17 @@ from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa -def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - clip_actions: bool = False, - network: Sequence[Mapping[str, Any]] = [], - output: Union[str, Sequence[str]] = "", - return_source: bool = False, - *args, - **kwargs) -> Union[Model, str]: +def deterministic_model( + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + clip_actions: bool = False, + network: Sequence[Mapping[str, Any]] = [], + output: Union[str, Sequence[str]] = "", + return_source: bool = False, + *args, + **kwargs, +) -> Union[Model, str]: """Instantiate a deterministic model :param observation_space: Observation/state space or shape (default: None). @@ -91,7 +93,6 @@ def compute(self, inputs, role=""): # instantiate model _locals = {} exec(template, globals(), _locals) - return _locals["DeterministicModel"](observation_space=observation_space, - action_space=action_space, - device=device, - clip_actions=clip_actions) + return _locals["DeterministicModel"]( + observation_space=observation_space, action_space=action_space, device=device, clip_actions=clip_actions + ) diff --git a/skrl/utils/model_instantiators/torch/gaussian.py b/skrl/utils/model_instantiators/torch/gaussian.py index b806cab0..b37cdefc 100644 --- a/skrl/utils/model_instantiators/torch/gaussian.py +++ b/skrl/utils/model_instantiators/torch/gaussian.py @@ -12,19 +12,21 @@ from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa -def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - clip_actions: bool = False, - clip_log_std: bool = True, - min_log_std: float = -20, - max_log_std: float = 2, - initial_log_std: float = 0, - network: Sequence[Mapping[str, Any]] = [], - output: Union[str, Sequence[str]] = "", - return_source: bool = False, - *args, - **kwargs) -> Union[Model, str]: +def gaussian_model( + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + clip_actions: bool = False, + clip_log_std: bool = True, + min_log_std: float = -20, + max_log_std: float = 2, + initial_log_std: float = 0, + network: Sequence[Mapping[str, Any]] = [], + output: Union[str, Sequence[str]] = "", + return_source: bool = False, + *args, + **kwargs, +) -> Union[Model, str]: """Instantiate a Gaussian model :param observation_space: Observation/state space or shape (default: None). 
@@ -105,10 +107,12 @@ def compute(self, inputs, role=""): # instantiate model _locals = {} exec(template, globals(), _locals) - return _locals["GaussianModel"](observation_space=observation_space, - action_space=action_space, - device=device, - clip_actions=clip_actions, - clip_log_std=clip_log_std, - min_log_std=min_log_std, - max_log_std=max_log_std) + return _locals["GaussianModel"]( + observation_space=observation_space, + action_space=action_space, + device=device, + clip_actions=clip_actions, + clip_log_std=clip_log_std, + min_log_std=min_log_std, + max_log_std=max_log_std, + ) diff --git a/skrl/utils/model_instantiators/torch/multivariate_gaussian.py b/skrl/utils/model_instantiators/torch/multivariate_gaussian.py index b7172cc9..41b62300 100644 --- a/skrl/utils/model_instantiators/torch/multivariate_gaussian.py +++ b/skrl/utils/model_instantiators/torch/multivariate_gaussian.py @@ -12,19 +12,21 @@ from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa -def multivariate_gaussian_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - clip_actions: bool = False, - clip_log_std: bool = True, - min_log_std: float = -20, - max_log_std: float = 2, - initial_log_std: float = 0, - network: Sequence[Mapping[str, Any]] = [], - output: Union[str, Sequence[str]] = "", - return_source: bool = False, - *args, - **kwargs) -> Union[Model, str]: +def multivariate_gaussian_model( + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + clip_actions: bool = False, + clip_log_std: bool = True, + min_log_std: float = -20, + max_log_std: float = 2, + initial_log_std: float = 0, + network: Sequence[Mapping[str, Any]] = [], + output: Union[str, Sequence[str]] = "", + return_source: bool = False, + *args, + **kwargs, +) -> Union[Model, str]: """Instantiate a multivariate Gaussian model :param observation_space: Observation/state space or shape (default: None). 
@@ -105,10 +107,12 @@ def compute(self, inputs, role=""): # instantiate model _locals = {} exec(template, globals(), _locals) - return _locals["MultivariateGaussianModel"](observation_space=observation_space, - action_space=action_space, - device=device, - clip_actions=clip_actions, - clip_log_std=clip_log_std, - min_log_std=min_log_std, - max_log_std=max_log_std) + return _locals["MultivariateGaussianModel"]( + observation_space=observation_space, + action_space=action_space, + device=device, + clip_actions=clip_actions, + clip_log_std=clip_log_std, + min_log_std=min_log_std, + max_log_std=max_log_std, + ) diff --git a/skrl/utils/model_instantiators/torch/shared.py b/skrl/utils/model_instantiators/torch/shared.py index a30b2efa..f8809b11 100644 --- a/skrl/utils/model_instantiators/torch/shared.py +++ b/skrl/utils/model_instantiators/torch/shared.py @@ -12,14 +12,16 @@ from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa -def shared_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, - device: Optional[Union[str, torch.device]] = None, - structure: str = "", - roles: Sequence[str] = [], - parameters: Sequence[Mapping[str, Any]] = [], - single_forward_pass: bool = True, - return_source: bool = False) -> Union[Model, str]: +def shared_model( + observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, + device: Optional[Union[str, torch.device]] = None, + structure: str = "", + roles: Sequence[str] = [], + parameters: Sequence[Mapping[str, Any]] = [], + single_forward_pass: bool = True, + return_source: bool = False, +) -> Union[Model, str]: """Instantiate a shared model :param observation_space: Observation/state space or shape (default: None). 
@@ -52,15 +54,26 @@ def shared_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Sp parameters[0]["network"], parameters[0]["output"] = convert_deprecated_parameters(parameters[0]) parameters[1]["network"], parameters[1]["output"] = convert_deprecated_parameters(parameters[1]) # delete deprecated parameters - for parameter in ["input_shape", "hiddens", "hidden_activation", "output_shape", "output_activation", "output_scale"]: + for parameter in [ + "input_shape", + "hiddens", + "hidden_activation", + "output_shape", + "output_activation", + "output_scale", + ]: if parameter in parameters[0]: del parameters[0][parameter] if parameter in parameters[1]: del parameters[1][parameter] # parse model definitions - containers_gaussian, output_gaussian = generate_containers(parameters[0]["network"], parameters[0]["output"], embed_output=False, indent=1) - containers_deterministic, output_deterministic = generate_containers(parameters[1]["network"], parameters[1]["output"], embed_output=False, indent=1) + containers_gaussian, output_gaussian = generate_containers( + parameters[0]["network"], parameters[0]["output"], embed_output=False, indent=1 + ) + containers_deterministic, output_deterministic = generate_containers( + parameters[1]["network"], parameters[1]["output"], embed_output=False, indent=1 + ) # network definitions networks_common = [] @@ -68,7 +81,9 @@ def shared_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Sp for container in containers_gaussian: networks_common.append(f'self.{container["name"]}_container = {container["sequential"]}') forward_common.append(f'{container["name"]} = self.{container["name"]}_container({container["input"]})') - forward_common.insert(0, 'taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions"))') + forward_common.insert( + 0, 'taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions"))' + ) forward_common.insert(0, 'states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))') # process output @@ -86,7 +101,9 @@ def shared_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Sp forward_deterministic = [] if output_deterministic["modules"]: networks_deterministic.append(f'self.{roles[1]}_layer = {output_deterministic["modules"][0]}') - forward_deterministic.append(f'output = self.{roles[1]}_layer({"shared_output" if single_forward_pass else container["name"]})') + forward_deterministic.append( + f'output = self.{roles[1]}_layer({"shared_output" if single_forward_pass else container["name"]})' + ) if output_deterministic["output"]: forward_deterministic.append(f'output = {output_deterministic["output"]}') else: @@ -98,14 +115,19 @@ def shared_model(observation_space: Optional[Union[int, Tuple[int], gymnasium.Sp networks_deterministic = textwrap.indent("\n".join(networks_deterministic), prefix=" " * 8)[8:] if single_forward_pass: - forward_deterministic = [ - "if self._shared_output is None:", - ] + [" " + item for item in forward_common] + [ - f' shared_output = {container["name"]}', - "else:", - " shared_output = self._shared_output", - "self._shared_output = None", - ] + forward_deterministic + forward_deterministic = ( + [ + "if self._shared_output is None:", + ] + + [" " + item for item in forward_common] + + [ + f' shared_output = {container["name"]}', + "else:", + " shared_output = self._shared_output", + "self._shared_output = None", + ] + + forward_deterministic + ) forward_common.append(f'self._shared_output = 
{container["name"]}') forward_common = textwrap.indent("\n".join(forward_common), prefix=" " * 12)[12:] else: @@ -137,7 +159,7 @@ def act(self, inputs, role): return DeterministicMixin.act(self, inputs, role) """ if single_forward_pass: - template +=f""" + template += f""" def compute(self, inputs, role=""): if role == "{roles[0]}": {forward_common} @@ -148,7 +170,7 @@ def compute(self, inputs, role=""): return output, {{}} """ else: - template +=f""" + template += f""" def compute(self, inputs, role=""): {forward_common} if role == "{roles[0]}": @@ -165,6 +187,6 @@ def compute(self, inputs, role=""): # instantiate model _locals = {} exec(template, globals(), _locals) - return _locals["GaussianDeterministicModel"](observation_space=observation_space, - action_space=action_space, - device=device) + return _locals["GaussianDeterministicModel"]( + observation_space=observation_space, action_space=action_space, device=device + ) diff --git a/skrl/utils/omniverse_isaacgym_utils.py b/skrl/utils/omniverse_isaacgym_utils.py index fa86f6b1..f75c08b0 100644 --- a/skrl/utils/omniverse_isaacgym_utils.py +++ b/skrl/utils/omniverse_isaacgym_utils.py @@ -28,11 +28,13 @@ def _np_quat_mul(a, b): return np.stack([x, y, z, w], axis=-1).reshape(shape) + def _np_quat_conjugate(a): shape = a.shape a = a.reshape(-1, 4) return np.concatenate((-a[:, :3], a[:, -1:]), axis=-1).reshape(shape) + def _torch_quat_mul(a, b): assert a.shape == b.shape shape = a.shape @@ -53,19 +55,23 @@ def _torch_quat_mul(a, b): return torch.stack([w, x, y, z], dim=-1).view(shape) + def _torch_quat_conjugate(a): # wxyz shape = a.shape a = a.reshape(-1, 4) return torch.cat((a[:, :1], -a[:, 1:]), dim=-1).view(shape) -def ik(jacobian_end_effector: torch.Tensor, - current_position: torch.Tensor, - current_orientation: torch.Tensor, - goal_position: torch.Tensor, - goal_orientation: Optional[torch.Tensor] = None, - method: str = "damped least-squares", - method_cfg: Mapping[str, float] = {"scale": 1, "damping": 0.05, "min_singular_value": 1e-5}, - squeeze_output: bool = True,) -> torch.Tensor: + +def ik( + jacobian_end_effector: torch.Tensor, + current_position: torch.Tensor, + current_orientation: torch.Tensor, + goal_position: torch.Tensor, + goal_orientation: Optional[torch.Tensor] = None, + method: str = "damped least-squares", + method_cfg: Mapping[str, float] = {"scale": 1, "damping": 0.05, "min_singular_value": 1e-5}, + squeeze_output: bool = True, +) -> torch.Tensor: """Differential inverse kinematics :param jacobian_end_effector: End effector's jacobian @@ -109,9 +115,13 @@ def ik(jacobian_end_effector: torch.Tensor, if isinstance(jacobian_end_effector, torch.Tensor): # compute error q = _torch_quat_mul(goal_orientation, _torch_quat_conjugate(current_orientation)) - error = torch.cat([goal_position - current_position, # position error - q[:, 1:] * torch.sign(q[:, 0]).unsqueeze(-1)], # orientation error - dim=-1).unsqueeze(-1) + error = torch.cat( + [ + goal_position - current_position, # position error + q[:, 1:] * torch.sign(q[:, 0]).unsqueeze(-1), + ], # orientation error + dim=-1, + ).unsqueeze(-1) scale = method_cfg.get("scale", 1.0) @@ -143,9 +153,11 @@ def ik(jacobian_end_effector: torch.Tensor, elif method == "damped least-squares": damping = method_cfg.get("damping", 0.05) transpose = torch.transpose(jacobian_end_effector, 1, 2) - lmbda = torch.eye(jacobian_end_effector.shape[1], device=jacobian_end_effector.device) * (damping ** 2) + lmbda = torch.eye(jacobian_end_effector.shape[1], device=jacobian_end_effector.device) * 
(damping**2) if squeeze_output: - return (scale * transpose @ torch.inverse(jacobian_end_effector @ transpose + lmbda) @ error).squeeze(dim=2) + return (scale * transpose @ torch.inverse(jacobian_end_effector @ transpose + lmbda) @ error).squeeze( + dim=2 + ) else: return scale * transpose @ torch.inverse(jacobian_end_effector @ transpose + lmbda) @ error else: @@ -156,21 +168,22 @@ def ik(jacobian_end_effector: torch.Tensor, else: # compute error q = _np_quat_mul(goal_orientation, _np_quat_conjugate(current_orientation)) - error = np.concatenate([goal_position - current_position, # position error - q[:, 0:3] * np.sign(q[:, 3])]) # orientation error + error = np.concatenate( + [goal_position - current_position, q[:, 0:3] * np.sign(q[:, 3])] # position error + ) # orientation error # solve damped least squares (dO = J.T * V) transpose = np.transpose(jacobian_end_effector, 1, 2) lmbda = np.eye(6) * (method_cfg.get("damping", 0.05) ** 2) if squeeze_output: - return (transpose @ np.linalg.inv(jacobian_end_effector @ transpose + lmbda) @ error) + return transpose @ np.linalg.inv(jacobian_end_effector @ transpose + lmbda) @ error else: return transpose @ np.linalg.inv(jacobian_end_effector @ transpose + lmbda) @ error -def get_env_instance(headless: bool = True, - enable_livestream: bool = False, - enable_viewport: bool = False, - multi_threaded: bool = False) -> "omni.isaac.gym.vec_env.VecEnvBase": + +def get_env_instance( + headless: bool = True, enable_livestream: bool = False, enable_viewport: bool = False, multi_threaded: bool = False +) -> "omni.isaac.gym.vec_env.VecEnvBase": """ Instantiate a VecEnvBase-based object compatible with OmniIsaacGymEnvs @@ -257,7 +270,9 @@ def get_env_instance(headless: bool = True, class _OmniIsaacGymVecEnv(VecEnvBase): def step(self, actions): - actions = torch.clamp(actions, -self._task.clip_actions, self._task.clip_actions).to(self._task.device).clone() + actions = ( + torch.clamp(actions, -self._task.clip_actions, self._task.clip_actions).to(self._task.device).clone() + ) self._task.pre_physics_step(actions) for _ in range(self._task.control_frequency_inv): @@ -266,8 +281,16 @@ def step(self, actions): observations, rewards, dones, info = self._task.post_physics_step() - return {"obs": torch.clamp(observations, -self._task.clip_obs, self._task.clip_obs).to(self._task.rl_device).clone()}, \ - rewards.to(self._task.rl_device).clone(), dones.to(self._task.rl_device).clone(), info.copy() + return ( + { + "obs": torch.clamp(observations, -self._task.clip_obs, self._task.clip_obs) + .to(self._task.rl_device) + .clone() + }, + rewards.to(self._task.rl_device).clone(), + dones.to(self._task.rl_device).clone(), + info.copy(), + ) def reset(self): self._task.reset() @@ -292,7 +315,9 @@ def run(self, trainer=None): super().run(_OmniIsaacGymTrainerMT() if trainer is None else trainer) def _parse_data(self, data): - self._observations = torch.clamp(data["obs"], -self._task.clip_obs, self._task.clip_obs).to(self._task.rl_device).clone() + self._observations = ( + torch.clamp(data["obs"], -self._task.clip_obs, self._task.clip_obs).to(self._task.rl_device).clone() + ) self._rewards = data["rew"].to(self._task.rl_device).clone() self._dones = data["reset"].to(self._task.rl_device).clone() self._info = data["extras"].copy() @@ -320,13 +345,17 @@ def close(self): if multi_threaded: try: - return _OmniIsaacGymVecEnvMT(headless=headless, enable_livestream=enable_livestream, enable_viewport=enable_viewport) + return _OmniIsaacGymVecEnvMT( + headless=headless, 
enable_livestream=enable_livestream, enable_viewport=enable_viewport + ) except TypeError: logger.warning("Using an older version of Isaac Sim (2022.2.0 or earlier)") return _OmniIsaacGymVecEnvMT(headless=headless) # Isaac Sim 2022.2.0 and earlier else: try: - return _OmniIsaacGymVecEnv(headless=headless, enable_livestream=enable_livestream, enable_viewport=enable_viewport) + return _OmniIsaacGymVecEnv( + headless=headless, enable_livestream=enable_livestream, enable_viewport=enable_viewport + ) except TypeError: logger.warning("Using an older version of Isaac Sim (2022.2.0 or earlier)") return _OmniIsaacGymVecEnv(headless=headless) # Isaac Sim 2022.2.0 and earlier diff --git a/skrl/utils/postprocessing.py b/skrl/utils/postprocessing.py index 3a05dd72..238b1912 100644 --- a/skrl/utils/postprocessing.py +++ b/skrl/utils/postprocessing.py @@ -10,7 +10,7 @@ from skrl import logger -class MemoryFileIterator(): +class MemoryFileIterator: def __init__(self, pathname: str) -> None: """Python iterator for loading data from exported memories @@ -39,7 +39,7 @@ def __init__(self, pathname: str) -> None: self.n = 0 self.file_paths = sorted(glob.glob(pathname)) - def __iter__(self) -> 'MemoryFileIterator': + def __iter__(self) -> "MemoryFileIterator": """Return self to make iterable""" return self @@ -80,6 +80,7 @@ def _format_torch(self) -> Tuple[str, dict]: :rtype: tuple """ import torch + filename = os.path.basename(self.file_paths[self.n]) data = torch.load(self.file_paths[self.n]) @@ -94,7 +95,7 @@ def _format_csv(self) -> Tuple[str, dict]: """ filename = os.path.basename(self.file_paths[self.n]) - with open(self.file_paths[self.n], 'r') as f: + with open(self.file_paths[self.n], "r") as f: reader = csv.reader(f) # parse header @@ -115,14 +116,18 @@ def _format_csv(self) -> Tuple[str, dict]: data = {name: [] for name in names} for row in reader: for name, index in zip(names, indexes): - data[name].append([float(item) if item not in ["True", "False"] else bool(item) \ - for item in row[index[0]:index[1]]]) + data[name].append( + [ + float(item) if item not in ["True", "False"] else bool(item) + for item in row[index[0] : index[1]] + ] + ) self.n += 1 return filename, data -class TensorboardFileIterator(): +class TensorboardFileIterator: def __init__(self, pathname: str, tags: Union[str, List[str]]) -> None: """Python iterator for loading data from Tensorboard files @@ -141,7 +146,7 @@ def __init__(self, pathname: str, tags: Union[str, List[str]]) -> None: self.file_paths = sorted(glob.glob(pathname)) self.tags = [tags] if isinstance(tags, str) else tags - def __iter__(self) -> 'TensorboardFileIterator': + def __iter__(self) -> "TensorboardFileIterator": """Return self to make iterable""" return self diff --git a/skrl/utils/runner/jax/runner.py b/skrl/utils/runner/jax/runner.py index 27ce4f7b..a7291a69 100644 --- a/skrl/utils/runner/jax/runner.py +++ b/skrl/utils/runner/jax/runner.py @@ -56,14 +56,12 @@ def __init__(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, @property def trainer(self) -> Trainer: - """Trainer instance - """ + """Trainer instance""" return self._trainer @property def agent(self) -> Agent: - """Agent instance - """ + """Agent instance""" return self._agent @staticmethod @@ -132,7 +130,9 @@ def update_dict(d): return update_dict(copy.deepcopy(cfg)) - def _generate_models(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, Any]) -> Mapping[str, Mapping[str, Model]]: + def _generate_models( + self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: 
Mapping[str, Any] + ) -> Mapping[str, Mapping[str, Model]]: """Generate model instances according to the environment specification and the given config :param env: Wrapped environment @@ -170,7 +170,9 @@ def _generate_models(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappi del _cfg["models"]["policy"]["class"] except KeyError: model_class = self._class("GaussianMixin") - logger.warning("No 'class' field defined in 'models:policy' cfg. 'GaussianMixin' will be used as default") + logger.warning( + "No 'class' field defined in 'models:policy' cfg. 'GaussianMixin' will be used as default" + ) # print model source source = model_class( observation_space=observation_spaces[agent_id], @@ -195,7 +197,9 @@ def _generate_models(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappi del _cfg["models"]["value"]["class"] except KeyError: model_class = self._class("DeterministicMixin") - logger.warning("No 'class' field defined in 'models:value' cfg. 'DeterministicMixin' will be used as default") + logger.warning( + "No 'class' field defined in 'models:value' cfg. 'DeterministicMixin' will be used as default" + ) # print model source source = model_class( observation_space=(state_spaces if agent_class in [MAPPO] else observation_spaces)[agent_id], @@ -220,11 +224,15 @@ def _generate_models(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappi try: del _cfg["models"]["policy"]["class"] except KeyError: - logger.warning("No 'class' field defined in 'models:policy' cfg. 'GaussianMixin' will be used as default") + logger.warning( + "No 'class' field defined in 'models:policy' cfg. 'GaussianMixin' will be used as default" + ) try: del _cfg["models"]["value"]["class"] except KeyError: - logger.warning("No 'class' field defined in 'models:value' cfg. 'DeterministicMixin' will be used as default") + logger.warning( + "No 'class' field defined in 'models:value' cfg. 'DeterministicMixin' will be used as default" + ) model_class = self._class("Shared") # print model source source = model_class( @@ -263,7 +271,12 @@ def _generate_models(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappi return models - def _generate_agent(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, Any], models: Mapping[str, Mapping[str, Model]]) -> Agent: + def _generate_agent( + self, + env: Union[Wrapper, MultiAgentEnvWrapper], + cfg: Mapping[str, Any], + models: Mapping[str, Mapping[str, Model]], + ) -> Agent: """Generate agent instance according to the environment specification and the given config and models :param env: Wrapped environment @@ -282,7 +295,9 @@ def _generate_agent(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappin # check for memory configuration (backward compatibility) if not "memory" in cfg: - logger.warning("Deprecation warning: No 'memory' field defined in cfg. Using the default generated configuration") + logger.warning( + "Deprecation warning: No 'memory' field defined in cfg. 
Using the default generated configuration" + ) cfg["memory"] = {"class": "RandomMemory", "memory_size": -1} # get memory class and remove 'class' field try: @@ -322,10 +337,9 @@ def _generate_agent(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappin elif agent_class in [IPPO]: agent_cfg = IPPO_DEFAULT_CONFIG.copy() agent_cfg.update(self._process_cfg(cfg["agent"])) - agent_cfg["state_preprocessor_kwargs"].update({ - agent_id: {"size": observation_spaces[agent_id], "device": device} - for agent_id in possible_agents - }) + agent_cfg["state_preprocessor_kwargs"].update( + {agent_id: {"size": observation_spaces[agent_id], "device": device} for agent_id in possible_agents} + ) agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": device}) agent_kwargs = { "models": models, @@ -337,10 +351,9 @@ def _generate_agent(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappin elif agent_class in [MAPPO]: agent_cfg = MAPPO_DEFAULT_CONFIG.copy() agent_cfg.update(self._process_cfg(cfg["agent"])) - agent_cfg["state_preprocessor_kwargs"].update({ - agent_id: {"size": observation_spaces[agent_id], "device": device} - for agent_id in possible_agents - }) + agent_cfg["state_preprocessor_kwargs"].update( + {agent_id: {"size": observation_spaces[agent_id], "device": device} for agent_id in possible_agents} + ) agent_cfg["shared_state_preprocessor_kwargs"].update( {agent_id: {"size": state_spaces[agent_id], "device": device} for agent_id in possible_agents} ) @@ -355,7 +368,9 @@ def _generate_agent(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappin } return agent_class(cfg=agent_cfg, device=device, **agent_kwargs) - def _generate_trainer(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, Any], agent: Agent) -> Trainer: + def _generate_trainer( + self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, Any], agent: Agent + ) -> Trainer: """Generate trainer instance according to the environment specification and the given config and agent :param env: Wrapped environment diff --git a/skrl/utils/runner/torch/runner.py b/skrl/utils/runner/torch/runner.py index 5e4f6b52..d4f3d367 100644 --- a/skrl/utils/runner/torch/runner.py +++ b/skrl/utils/runner/torch/runner.py @@ -56,14 +56,12 @@ def __init__(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, @property def trainer(self) -> Trainer: - """Trainer instance - """ + """Trainer instance""" return self._trainer @property def agent(self) -> Agent: - """Agent instance - """ + """Agent instance""" return self._agent @staticmethod @@ -132,7 +130,9 @@ def update_dict(d): return update_dict(copy.deepcopy(cfg)) - def _generate_models(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, Any]) -> Mapping[str, Mapping[str, Model]]: + def _generate_models( + self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, Any] + ) -> Mapping[str, Mapping[str, Model]]: """Generate model instances according to the environment specification and the given config :param env: Wrapped environment @@ -167,7 +167,9 @@ def _generate_models(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappi del _cfg["models"]["policy"]["class"] except KeyError: model_class = self._class("GaussianMixin") - logger.warning("No 'class' field defined in 'models:policy' cfg. 'GaussianMixin' will be used as default") + logger.warning( + "No 'class' field defined in 'models:policy' cfg. 
'GaussianMixin' will be used as default" + ) # print model source source = model_class( observation_space=observation_spaces[agent_id], @@ -192,7 +194,9 @@ def _generate_models(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappi del _cfg["models"]["value"]["class"] except KeyError: model_class = self._class("DeterministicMixin") - logger.warning("No 'class' field defined in 'models:value' cfg. 'DeterministicMixin' will be used as default") + logger.warning( + "No 'class' field defined in 'models:value' cfg. 'DeterministicMixin' will be used as default" + ) # print model source source = model_class( observation_space=(state_spaces if agent_class in [MAPPO] else observation_spaces)[agent_id], @@ -217,11 +221,15 @@ def _generate_models(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappi try: del _cfg["models"]["policy"]["class"] except KeyError: - logger.warning("No 'class' field defined in 'models:policy' cfg. 'GaussianMixin' will be used as default") + logger.warning( + "No 'class' field defined in 'models:policy' cfg. 'GaussianMixin' will be used as default" + ) try: del _cfg["models"]["value"]["class"] except KeyError: - logger.warning("No 'class' field defined in 'models:value' cfg. 'DeterministicMixin' will be used as default") + logger.warning( + "No 'class' field defined in 'models:value' cfg. 'DeterministicMixin' will be used as default" + ) model_class = self._class("Shared") # print model source source = model_class( @@ -255,7 +263,12 @@ def _generate_models(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappi return models - def _generate_agent(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, Any], models: Mapping[str, Mapping[str, Model]]) -> Agent: + def _generate_agent( + self, + env: Union[Wrapper, MultiAgentEnvWrapper], + cfg: Mapping[str, Any], + models: Mapping[str, Mapping[str, Model]], + ) -> Agent: """Generate agent instance according to the environment specification and the given config and models :param env: Wrapped environment @@ -274,7 +287,9 @@ def _generate_agent(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappin # check for memory configuration (backward compatibility) if not "memory" in cfg: - logger.warning("Deprecation warning: No 'memory' field defined in cfg. Using the default generated configuration") + logger.warning( + "Deprecation warning: No 'memory' field defined in cfg. 
Using the default generated configuration" + ) cfg["memory"] = {"class": "RandomMemory", "memory_size": -1} # get memory class and remove 'class' field try: @@ -314,10 +329,9 @@ def _generate_agent(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappin elif agent_class in [IPPO]: agent_cfg = IPPO_DEFAULT_CONFIG.copy() agent_cfg.update(self._process_cfg(cfg["agent"])) - agent_cfg["state_preprocessor_kwargs"].update({ - agent_id: {"size": observation_spaces[agent_id], "device": device} - for agent_id in possible_agents - }) + agent_cfg["state_preprocessor_kwargs"].update( + {agent_id: {"size": observation_spaces[agent_id], "device": device} for agent_id in possible_agents} + ) agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": device}) agent_kwargs = { "models": models, @@ -329,10 +343,9 @@ def _generate_agent(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappin elif agent_class in [MAPPO]: agent_cfg = MAPPO_DEFAULT_CONFIG.copy() agent_cfg.update(self._process_cfg(cfg["agent"])) - agent_cfg["state_preprocessor_kwargs"].update({ - agent_id: {"size": observation_spaces[agent_id], "device": device} - for agent_id in possible_agents - }) + agent_cfg["state_preprocessor_kwargs"].update( + {agent_id: {"size": observation_spaces[agent_id], "device": device} for agent_id in possible_agents} + ) agent_cfg["shared_state_preprocessor_kwargs"].update( {agent_id: {"size": state_spaces[agent_id], "device": device} for agent_id in possible_agents} ) @@ -347,7 +360,9 @@ def _generate_agent(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappin } return agent_class(cfg=agent_cfg, device=device, **agent_kwargs) - def _generate_trainer(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, Any], agent: Agent) -> Trainer: + def _generate_trainer( + self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, Any], agent: Agent + ) -> Trainer: """Generate trainer instance according to the environment specification and the given config and agent :param env: Wrapped environment diff --git a/skrl/utils/spaces/jax/__init__.py b/skrl/utils/spaces/jax/__init__.py index 9e8de363..e35e1b82 100644 --- a/skrl/utils/spaces/jax/__init__.py +++ b/skrl/utils/spaces/jax/__init__.py @@ -5,5 +5,5 @@ sample_space, tensorize_space, unflatten_tensorized_space, - untensorize_space + untensorize_space, ) diff --git a/skrl/utils/spaces/jax/spaces.py b/skrl/utils/spaces/jax/spaces.py index 663d1fb7..a6ed0a7a 100644 --- a/skrl/utils/spaces/jax/spaces.py +++ b/skrl/utils/spaces/jax/spaces.py @@ -37,7 +37,10 @@ def convert_gym_space(space: "gym.Space", squeeze_batch_dimension: bool = False) return spaces.Dict(spaces={k: convert_gym_space(v) for k, v in space.spaces.items()}) raise ValueError(f"Unsupported space ({space})") -def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, jax.Device]] = None, _jax: bool = True) -> Any: + +def tensorize_space( + space: spaces.Space, x: Any, device: Optional[Union[str, jax.Device]] = None, _jax: bool = True +) -> Any: """Convert the sample/value items of a given gymnasium space to JAX array. :param space: Gymnasium space. 
@@ -100,6 +103,7 @@ def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, jax
         return tuple([tensorize_space(s, _x, device) for s, _x in zip(space, x)])
     raise ValueError(f"Unsupported space ({space})")
 
+
 def untensorize_space(space: spaces.Space, x: Any, squeeze_batch_dimension: bool = True) -> Any:
     """Convert a tensorized space to a gymnasium space with expected sample/value item types.
 
@@ -160,6 +164,7 @@ def untensorize_space(space: spaces.Space, x: Any, squeeze_batch_dimension: bool
         return tuple([untensorize_space(s, _x, squeeze_batch_dimension) for s, _x in zip(space, x)])
     raise ValueError(f"Unsupported space ({space})")
 
+
 def flatten_tensorized_space(x: Any, _jax: bool = True) -> jax.Array:
     """Flatten a tensorized space.
 
@@ -188,6 +193,7 @@ def flatten_tensorized_space(x: Any, _jax: bool = True) -> jax.Array:
         return np.concatenate([flatten_tensorized_space(_x) for _x in x], axis=-1)
     raise ValueError(f"Unsupported sample/value type ({type(x)})")
 
+
 def unflatten_tensorized_space(space: Union[spaces.Space, Sequence[int], int], x: jax.Array) -> Any:
     """Unflatten a tensor to create a tensorized space.
 
@@ -231,6 +237,7 @@ def unflatten_tensorized_space(space: Union[spaces.Space, Sequence[int], int], x
         return output
     raise ValueError(f"Unsupported space ({space})")
 
+
 def compute_space_size(space: Union[spaces.Space, Sequence[int], int], occupied_size: bool = False) -> int:
     """Get the size (number of elements) of a space.
 
@@ -264,7 +271,8 @@ def compute_space_size(space: Union[spaces.Space, Sequence[int], int], occupied_
     # gymnasium computation
     return gymnasium.spaces.flatdim(space)
 
-def sample_space(space: spaces.Space, batch_size: int = 1, backend: str = Literal["numpy", "jax"], device = None) -> Any:
+
+def sample_space(space: spaces.Space, batch_size: int = 1, backend: str = Literal["numpy", "jax"], device=None) -> Any:
     """Generates a random sample from the specified space.
 
     :param space: Gymnasium space.
diff --git a/skrl/utils/spaces/torch/__init__.py b/skrl/utils/spaces/torch/__init__.py
index 62413382..93eab527 100644
--- a/skrl/utils/spaces/torch/__init__.py
+++ b/skrl/utils/spaces/torch/__init__.py
@@ -5,5 +5,5 @@
     sample_space,
     tensorize_space,
     unflatten_tensorized_space,
-    untensorize_space
+    untensorize_space,
 )
diff --git a/skrl/utils/spaces/torch/spaces.py b/skrl/utils/spaces/torch/spaces.py
index 579bf33c..1ee664be 100644
--- a/skrl/utils/spaces/torch/spaces.py
+++ b/skrl/utils/spaces/torch/spaces.py
@@ -34,6 +34,7 @@ def convert_gym_space(space: "gym.Space", squeeze_batch_dimension: bool = False)
         return spaces.Dict(spaces={k: convert_gym_space(v) for k, v in space.spaces.items()})
     raise ValueError(f"Unsupported space ({space})")
 
+
 def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, torch.device]] = None) -> Any:
     """Convert the sample/value items of a given gymnasium space to PyTorch tensors.
 
@@ -86,6 +87,7 @@ def tensorize_space(space: spaces.Space, x: Any, device: Optional[Union[str, tor
         return tuple([tensorize_space(s, _x, device) for s, _x in zip(space, x)])
     raise ValueError(f"Unsupported space ({space})")
 
+
 def untensorize_space(space: spaces.Space, x: Any, squeeze_batch_dimension: bool = True) -> Any:
     """Convert a tensorized space to a gymnasium space with expected sample/value item types.
 
@@ -134,6 +136,7 @@ def untensorize_space(space: spaces.Space, x: Any, squeeze_batch_dimension: bool
         return tuple([untensorize_space(s, _x, squeeze_batch_dimension) for s, _x in zip(space, x)])
     raise ValueError(f"Unsupported space ({space})")
 
+
 def flatten_tensorized_space(x: Any) -> torch.Tensor:
     """Flatten a tensorized space.
 
@@ -150,12 +153,13 @@ def flatten_tensorized_space(x: Any) -> torch.Tensor:
     # composite spaces
     # Dict
     elif isinstance(x, dict):
-        return torch.cat([flatten_tensorized_space(x[k])for k in sorted(x.keys())], dim=-1)
+        return torch.cat([flatten_tensorized_space(x[k]) for k in sorted(x.keys())], dim=-1)
     # Tuple
     elif type(x) in [list, tuple]:
         return torch.cat([flatten_tensorized_space(_x) for _x in x], dim=-1)
     raise ValueError(f"Unsupported sample/value type ({type(x)})")
 
+
 def unflatten_tensorized_space(space: Union[spaces.Space, Sequence[int], int], x: torch.Tensor) -> Any:
     """Unflatten a tensor to create a tensorized space.
 
@@ -199,6 +203,7 @@ def unflatten_tensorized_space(space: Union[spaces.Space, Sequence[int], int], x
         return output
     raise ValueError(f"Unsupported space ({space})")
 
+
 def compute_space_size(space: Union[spaces.Space, Sequence[int], int], occupied_size: bool = False) -> int:
     """Get the size (number of elements) of a space.
 
@@ -232,7 +237,10 @@ def compute_space_size(space: Union[spaces.Space, Sequence[int], int], occupied_
     # gymnasium computation
     return gymnasium.spaces.flatdim(space)
 
-def sample_space(space: spaces.Space, batch_size: int = 1, backend: str = Literal["numpy", "torch"], device = None) -> Any:
+
+def sample_space(
+    space: spaces.Space, batch_size: int = 1, backend: str = Literal["numpy", "torch"], device=None
+) -> Any:
     """Generates a random sample from the specified space.
 
     :param space: Gymnasium space.
From 540a0a7837ae2c8702bfe39768d791d409c9e14b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Mon, 4 Nov 2024 16:50:37 -0500 Subject: [PATCH 6/8] Apply black format to tests folder --- tests/jax/test_jax_model_instantiators.py | 92 ++++---- ...test_jax_model_instantiators_definition.py | 53 ++--- tests/jax/test_jax_utils_spaces.py | 13 +- tests/jax/test_jax_wrapper_gym.py | 1 + tests/jax/test_jax_wrapper_gymnasium.py | 1 + tests/jax/test_jax_wrapper_isaacgym.py | 2 +- tests/jax/test_jax_wrapper_isaaclab.py | 37 +++- tests/jax/test_jax_wrapper_omniisaacgym.py | 17 +- tests/jax/test_jax_wrapper_pettingzoo.py | 21 +- tests/stategies.py | 5 +- tests/test_agents.py | 49 +++-- tests/test_envs.py | 15 +- tests/test_examples_deepmind.py | 7 +- tests/test_examples_gym.py | 9 +- tests/test_examples_gymnasium.py | 14 +- tests/test_examples_isaac_orbit.py | 9 +- tests/test_examples_isaacgym.py | 7 +- tests/test_examples_isaacsim.py | 4 +- tests/test_examples_omniisaacgym.py | 9 +- tests/test_examples_robosuite.py | 4 +- tests/test_examples_shimmy.py | 12 +- tests/test_jax_memories_memory.py | 47 ++-- tests/test_memories.py | 20 +- tests/test_model_instantiators.py | 7 +- tests/test_resources_noises.py | 11 +- tests/test_resources_preprocessors.py | 18 +- tests/test_resources_schedulers.py | 5 +- tests/test_trainers.py | 9 +- tests/torch/test_torch_model_instantiators.py | 200 ++++++++++-------- ...st_torch_model_instantiators_definition.py | 113 +++++----- tests/torch/test_torch_utils_spaces.py | 13 +- tests/torch/test_torch_wrapper_deepmind.py | 5 +- tests/torch/test_torch_wrapper_gym.py | 1 + tests/torch/test_torch_wrapper_gymnasium.py | 1 + tests/torch/test_torch_wrapper_isaacgym.py | 2 +- tests/torch/test_torch_wrapper_isaaclab.py | 48 +++-- .../torch/test_torch_wrapper_omniisaacgym.py | 17 +- tests/torch/test_torch_wrapper_pettingzoo.py | 22 +- tests/utils.py | 10 +- 39 files changed, 576 insertions(+), 354 deletions(-) diff --git a/tests/jax/test_jax_model_instantiators.py b/tests/jax/test_jax_model_instantiators.py index 51cd2aaf..96e044d4 100644 --- a/tests/jax/test_jax_model_instantiators.py +++ b/tests/jax/test_jax_model_instantiators.py @@ -11,23 +11,27 @@ from skrl.utils.model_instantiators.jax import Shape, categorical_model, deterministic_model, gaussian_model -@hypothesis.given(observation_space_size=st.integers(min_value=1, max_value=10), - action_space_size=st.integers(min_value=1, max_value=10)) +@hypothesis.given( + observation_space_size=st.integers(min_value=1, max_value=10), + action_space_size=st.integers(min_value=1, max_value=10), +) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) @pytest.mark.parametrize("device", [None, "cpu", "cuda:0"]) def test_categorical_model(capsys, observation_space_size, action_space_size, device): observation_space = gym.spaces.Box(np.array([-1] * observation_space_size), np.array([1] * observation_space_size)) action_space = gym.spaces.Discrete(action_space_size) # TODO: randomize all parameters - model = categorical_model(observation_space=observation_space, - action_space=action_space, - device=device, - unnormalized_log_prob=True, - input_shape=Shape.STATES, - hiddens=[256, 256], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None) + model = categorical_model( + observation_space=observation_space, + action_space=action_space, + device=device, + unnormalized_log_prob=True, + input_shape=Shape.STATES, + 
hiddens=[256, 256], + hidden_activation=["relu", "relu"], + output_shape=Shape.ACTIONS, + output_activation=None, + ) model.init_state_dict("model") with jax.default_device(model.device): @@ -35,24 +39,29 @@ def test_categorical_model(capsys, observation_space_size, action_space_size, de output = model.act({"states": observations}) assert output[0].shape == (10, 1) -@hypothesis.given(observation_space_size=st.integers(min_value=1, max_value=10), - action_space_size=st.integers(min_value=1, max_value=10)) + +@hypothesis.given( + observation_space_size=st.integers(min_value=1, max_value=10), + action_space_size=st.integers(min_value=1, max_value=10), +) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) @pytest.mark.parametrize("device", [None, "cpu", "cuda:0"]) def test_deterministic_model(capsys, observation_space_size, action_space_size, device): observation_space = gym.spaces.Box(np.array([-1] * observation_space_size), np.array([1] * observation_space_size)) action_space = gym.spaces.Box(np.array([-1] * action_space_size), np.array([1] * action_space_size)) # TODO: randomize all parameters - model = deterministic_model(observation_space=observation_space, - action_space=action_space, - device=device, - clip_actions=False, - input_shape=Shape.STATES, - hiddens=[256, 256], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1) + model = deterministic_model( + observation_space=observation_space, + action_space=action_space, + device=device, + clip_actions=False, + input_shape=Shape.STATES, + hiddens=[256, 256], + hidden_activation=["relu", "relu"], + output_shape=Shape.ACTIONS, + output_activation=None, + output_scale=1, + ) model.init_state_dict("model") with jax.default_device(model.device): @@ -60,28 +69,33 @@ def test_deterministic_model(capsys, observation_space_size, action_space_size, output = model.act({"states": observations}) assert output[0].shape == (10, model.num_actions) -@hypothesis.given(observation_space_size=st.integers(min_value=1, max_value=10), - action_space_size=st.integers(min_value=1, max_value=10)) + +@hypothesis.given( + observation_space_size=st.integers(min_value=1, max_value=10), + action_space_size=st.integers(min_value=1, max_value=10), +) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) @pytest.mark.parametrize("device", [None, "cpu", "cuda:0"]) def test_gaussian_model(capsys, observation_space_size, action_space_size, device): observation_space = gym.spaces.Box(np.array([-1] * observation_space_size), np.array([1] * observation_space_size)) action_space = gym.spaces.Box(np.array([-1] * action_space_size), np.array([1] * action_space_size)) # TODO: randomize all parameters - model = gaussian_model(observation_space=observation_space, - action_space=action_space, - device=device, - clip_actions=False, - clip_log_std=True, - min_log_std=-20, - max_log_std=2, - initial_log_std=0, - input_shape=Shape.STATES, - hiddens=[256, 256], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1) + model = gaussian_model( + observation_space=observation_space, + action_space=action_space, + device=device, + clip_actions=False, + clip_log_std=True, + min_log_std=-20, + max_log_std=2, + initial_log_std=0, + input_shape=Shape.STATES, + hiddens=[256, 256], + hidden_activation=["relu", "relu"], + output_shape=Shape.ACTIONS, + output_activation=None, + 
output_scale=1, + ) model.init_state_dict("model") with jax.default_device(model.device): diff --git a/tests/jax/test_jax_model_instantiators_definition.py b/tests/jax/test_jax_model_instantiators_definition.py index f6992394..6c2f3de9 100644 --- a/tests/jax/test_jax_model_instantiators_definition.py +++ b/tests/jax/test_jax_model_instantiators_definition.py @@ -22,6 +22,7 @@ def test_get_activation_function(capsys): assert activation is not None, f"{item} -> None" exec(f"{activation}(x)", _globals, {}) + def test_parse_input(capsys): # check for Shape enum (compatibility with prior versions) for input in [Shape.STATES, Shape.OBSERVATIONS, Shape.ACTIONS, Shape.STATES_ACTIONS]: @@ -43,6 +44,7 @@ def test_parse_input(capsys): output = _parse_input(str(input)) assert output.replace("'", '"') == statement, f"'{output}' != '{statement}'" + def test_generate_modules(capsys): _globals = {"nn": flax.linen} @@ -138,6 +140,7 @@ def test_generate_modules(capsys): assert isinstance(container, flax.linen.Sequential) assert len(container.layers) == 2 + def test_gaussian_model(capsys): device = "cpu" observation_space = gym.spaces.Box(np.array([-1] * 5), np.array([1] * 5)) @@ -161,19 +164,15 @@ def test_gaussian_model(capsys): """ content = yaml.safe_load(content) # source - model = gaussian_model(observation_space=observation_space, - action_space=action_space, - device=device, - return_source=True, - **content) + model = gaussian_model( + observation_space=observation_space, action_space=action_space, device=device, return_source=True, **content + ) with capsys.disabled(): print(model) # instance - model = gaussian_model(observation_space=observation_space, - action_space=action_space, - device=device, - return_source=False, - **content) + model = gaussian_model( + observation_space=observation_space, action_space=action_space, device=device, return_source=False, **content + ) model.init_state_dict("model") with capsys.disabled(): print(model) @@ -182,6 +181,7 @@ def test_gaussian_model(capsys): output = model.act({"states": observations}) assert output[0].shape == (10, 2) + def test_deterministic_model(capsys): device = "cpu" observation_space = gym.spaces.Box(np.array([-1] * 5), np.array([1] * 5)) @@ -202,19 +202,15 @@ def test_deterministic_model(capsys): """ content = yaml.safe_load(content) # source - model = deterministic_model(observation_space=observation_space, - action_space=action_space, - device=device, - return_source=True, - **content) + model = deterministic_model( + observation_space=observation_space, action_space=action_space, device=device, return_source=True, **content + ) with capsys.disabled(): print(model) # instance - model = deterministic_model(observation_space=observation_space, - action_space=action_space, - device=device, - return_source=False, - **content) + model = deterministic_model( + observation_space=observation_space, action_space=action_space, device=device, return_source=False, **content + ) model.init_state_dict("model") with capsys.disabled(): print(model) @@ -223,6 +219,7 @@ def test_deterministic_model(capsys): output = model.act({"states": observations}) assert output[0].shape == (10, 3) + def test_categorical_model(capsys): device = "cpu" observation_space = gym.spaces.Box(np.array([-1] * 5), np.array([1] * 5)) @@ -242,19 +239,15 @@ def test_categorical_model(capsys): """ content = yaml.safe_load(content) # source - model = categorical_model(observation_space=observation_space, - action_space=action_space, - device=device, - return_source=True, - **content) 
+ model = categorical_model( + observation_space=observation_space, action_space=action_space, device=device, return_source=True, **content + ) with capsys.disabled(): print(model) # instance - model = categorical_model(observation_space=observation_space, - action_space=action_space, - device=device, - return_source=False, - **content) + model = categorical_model( + observation_space=observation_space, action_space=action_space, device=device, return_source=False, **content + ) model.init_state_dict("model") with capsys.disabled(): print(model) diff --git a/tests/jax/test_jax_utils_spaces.py b/tests/jax/test_jax_utils_spaces.py index 44d5d983..2a845193 100644 --- a/tests/jax/test_jax_utils_spaces.py +++ b/tests/jax/test_jax_utils_spaces.py @@ -15,7 +15,7 @@ sample_space, tensorize_space, unflatten_tensorized_space, - untensorize_space + untensorize_space, ) from ..stategies import gym_space_stategy, gymnasium_space_stategy @@ -29,6 +29,7 @@ def _check_backend(x, backend): else: raise ValueError(f"Invalid backend type: {backend}") + def check_sampled_space(space, x, n, backend): if isinstance(space, gymnasium.spaces.Box): _check_backend(x, backend) @@ -66,6 +67,7 @@ def occupied_size(s): space_size = compute_space_size(space, occupied_size=True) assert space_size == occupied_size(space) + @hypothesis.given(space=gymnasium_space_stategy()) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) def test_tensorize_space(capsys, space: gymnasium.spaces.Space): @@ -97,6 +99,7 @@ def check_tensorized_space(s, x, n): tensorized_space = tensorize_space(space, sampled_space) check_tensorized_space(space, tensorized_space, 5) + @hypothesis.given(space=gymnasium_space_stategy()) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) def test_untensorize_space(capsys, space: gymnasium.spaces.Space): @@ -108,7 +111,9 @@ def check_untensorized_space(s, x, squeeze_batch_dimension): assert isinstance(x, (np.ndarray, int)) assert isinstance(x, int) if squeeze_batch_dimension else x.shape == (1, 1) elif isinstance(s, gymnasium.spaces.MultiDiscrete): - assert isinstance(x, np.ndarray) and x.shape == s.nvec.shape if squeeze_batch_dimension else (1, *s.nvec.shape) + assert ( + isinstance(x, np.ndarray) and x.shape == s.nvec.shape if squeeze_batch_dimension else (1, *s.nvec.shape) + ) elif isinstance(s, gymnasium.spaces.Dict): list(map(check_untensorized_space, s.values(), x.values(), [squeeze_batch_dimension] * len(s))) elif isinstance(s, gymnasium.spaces.Tuple): @@ -124,6 +129,7 @@ def check_untensorized_space(s, x, squeeze_batch_dimension): untensorized_space = untensorize_space(space, tensorized_space, squeeze_batch_dimension=True) check_untensorized_space(space, untensorized_space, squeeze_batch_dimension=True) + @hypothesis.given(space=gymnasium_space_stategy(), batch_size=st.integers(min_value=1, max_value=10)) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) def test_sample_space(capsys, space: gymnasium.spaces.Space, batch_size: int): @@ -134,6 +140,7 @@ def test_sample_space(capsys, space: gymnasium.spaces.Space, batch_size: int): sampled_space = sample_space(space, batch_size, backend="jax") check_sampled_space(space, sampled_space, batch_size, backend="jax") + @hypothesis.given(space=gymnasium_space_stategy()) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) def 
test_flatten_tensorized_space(capsys, space: gymnasium.spaces.Space): @@ -147,6 +154,7 @@ def test_flatten_tensorized_space(capsys, space: gymnasium.spaces.Space): flattened_space = flatten_tensorized_space(tensorized_space) assert flattened_space.shape == (5, space_size) + @hypothesis.given(space=gymnasium_space_stategy()) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) def test_unflatten_tensorized_space(capsys, space: gymnasium.spaces.Space): @@ -160,6 +168,7 @@ def test_unflatten_tensorized_space(capsys, space: gymnasium.spaces.Space): unflattened_space = unflatten_tensorized_space(space, flattened_space) check_sampled_space(space, unflattened_space, 5, backend="jax") + @hypothesis.given(space=gym_space_stategy()) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) def test_convert_gym_space(capsys, space: gym.spaces.Space): diff --git a/tests/jax/test_jax_wrapper_gym.py b/tests/jax/test_jax_wrapper_gym.py index 7d7042ff..d2ef83cb 100644 --- a/tests/jax/test_jax_wrapper_gym.py +++ b/tests/jax/test_jax_wrapper_gym.py @@ -53,6 +53,7 @@ def test_env(capsys: pytest.CaptureFixture, backend: str): env.close() + @pytest.mark.parametrize("backend", ["jax", "numpy"]) @pytest.mark.parametrize("vectorization_mode", ["async", "sync"]) def test_vectorized_env(capsys: pytest.CaptureFixture, backend: str, vectorization_mode: str): diff --git a/tests/jax/test_jax_wrapper_gymnasium.py b/tests/jax/test_jax_wrapper_gymnasium.py index f8b47d5b..b7676d3f 100644 --- a/tests/jax/test_jax_wrapper_gymnasium.py +++ b/tests/jax/test_jax_wrapper_gymnasium.py @@ -52,6 +52,7 @@ def test_env(capsys: pytest.CaptureFixture, backend: str): env.close() + @pytest.mark.parametrize("backend", ["jax", "numpy"]) @pytest.mark.parametrize("vectorization_mode", ["async", "sync"]) def test_vectorized_env(capsys: pytest.CaptureFixture, backend: str, vectorization_mode: str): diff --git a/tests/jax/test_jax_wrapper_isaacgym.py b/tests/jax/test_jax_wrapper_isaacgym.py index 4629e666..d129bc02 100644 --- a/tests/jax/test_jax_wrapper_isaacgym.py +++ b/tests/jax/test_jax_wrapper_isaacgym.py @@ -30,7 +30,7 @@ def __init__(self, num_states) -> None: self.state_space = gym.spaces.Box(np.ones(self.num_states) * -np.Inf, np.ones(self.num_states) * np.Inf) self.observation_space = gym.spaces.Box(np.ones(self.num_obs) * -np.Inf, np.ones(self.num_obs) * np.Inf) - self.action_space = gym.spaces.Box(np.ones(self.num_actions) * -1., np.ones(self.num_actions) * 1.) 
+ self.action_space = gym.spaces.Box(np.ones(self.num_actions) * -1.0, np.ones(self.num_actions) * 1.0) def reset(self) -> Dict[str, torch.Tensor]: obs_dict = {} diff --git a/tests/jax/test_jax_wrapper_isaaclab.py b/tests/jax/test_jax_wrapper_isaaclab.py index 9b62c5e4..b454fd2d 100644 --- a/tests/jax/test_jax_wrapper_isaaclab.py +++ b/tests/jax/test_jax_wrapper_isaaclab.py @@ -26,6 +26,7 @@ Dict[AgentID, dict], ] + class IsaacLabEnv(gym.Env): def __init__(self, num_states) -> None: self.num_actions = 1 @@ -61,7 +62,9 @@ def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) def step(self, action: torch.Tensor) -> VecEnvStepReturn: assert action.clone().shape == torch.Size([self.num_envs, 1]) - observations = {"policy": torch.ones((self.num_envs, self.num_observations), device=self.device, dtype=torch.float32)} + observations = { + "policy": torch.ones((self.num_envs, self.num_observations), device=self.device, dtype=torch.float32) + } rewards = torch.zeros(self.num_envs, device=self.device, dtype=torch.float32) terminated = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) truncated = torch.zeros_like(terminated) @@ -102,9 +105,7 @@ def _configure_env_spaces(self): if not self.num_states: self.state_space = None if self.num_states < 0: - self.state_space = gym.spaces.Box( - low=-np.inf, high=np.inf, shape=(sum(self.num_observations.values()),) - ) + self.state_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(sum(self.num_observations.values()),)) else: self.state_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.num_states,)) @@ -112,16 +113,28 @@ def _configure_env_spaces(self): def unwrapped(self): return self - def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[Dict[AgentID, ObsType], dict]: - observations = {agent: torch.ones((self.num_envs, self.num_observations[agent]), device=self.device) for agent in self.possible_agents} + def reset( + self, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[Dict[AgentID, ObsType], dict]: + observations = { + agent: torch.ones((self.num_envs, self.num_observations[agent]), device=self.device) + for agent in self.possible_agents + } return observations, self.extras def step(self, action: Dict[AgentID, torch.Tensor]) -> EnvStepReturn: for agent in self.possible_agents: assert action[agent].clone().shape == torch.Size([self.num_envs, self.num_actions[agent]]) - observations = {agent: torch.ones((self.num_envs, self.num_observations[agent]), device=self.device) for agent in self.possible_agents} - rewards = {agent: torch.zeros(self.num_envs, device=self.device, dtype=torch.float32) for agent in self.possible_agents} - terminated = {agent: torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) for agent in self.possible_agents} + observations = { + agent: torch.ones((self.num_envs, self.num_observations[agent]), device=self.device) + for agent in self.possible_agents + } + rewards = { + agent: torch.zeros(self.num_envs, device=self.device, dtype=torch.float32) for agent in self.possible_agents + } + terminated = { + agent: torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) for agent in self.possible_agents + } truncated = {agent: torch.zeros_like(terminated[agent]) for agent in self.possible_agents} return observations, rewards, terminated, truncated, self.extras @@ -183,6 +196,7 @@ def test_env(capsys: pytest.CaptureFixture, backend: str, num_states: int): env.close() + @pytest.mark.parametrize("backend", 
["jax", "numpy"]) @pytest.mark.parametrize("num_states", [0, 5]) def test_multi_agent_env(capsys: pytest.CaptureFixture, backend: str, num_states: int): @@ -192,7 +206,10 @@ def test_multi_agent_env(capsys: pytest.CaptureFixture, backend: str, num_states num_envs = 10 num_agents = 3 possible_agents = [f"agent_{i}" for i in range(num_agents)] - action = {f"agent_{i}": jnp.ones((num_envs, i + 10)) if backend == "jax" else np.ones((num_envs, i + 10)) for i in range(num_agents)} + action = { + f"agent_{i}": jnp.ones((num_envs, i + 10)) if backend == "jax" else np.ones((num_envs, i + 10)) + for i in range(num_agents) + } # load wrap the environment original_env = IsaacLabMultiAgentEnv(num_states) diff --git a/tests/jax/test_jax_wrapper_omniisaacgym.py b/tests/jax/test_jax_wrapper_omniisaacgym.py index 5f59014d..caf8e606 100644 --- a/tests/jax/test_jax_wrapper_omniisaacgym.py +++ b/tests/jax/test_jax_wrapper_omniisaacgym.py @@ -28,9 +28,16 @@ def __init__(self, num_states) -> None: self.device = "cpu" # initialize data spaces (defaults to gym.Box) - self.action_space = gym.spaces.Box(np.ones(self.num_actions, dtype=np.float32) * -1.0, np.ones(self.num_actions, dtype=np.float32) * 1.0) - self.observation_space = gym.spaces.Box(np.ones(self.num_observations, dtype=np.float32) * -np.Inf, np.ones(self.num_observations, dtype=np.float32) * np.Inf) - self.state_space = gym.spaces.Box(np.ones(self.num_states, dtype=np.float32) * -np.Inf, np.ones(self.num_states, dtype=np.float32) * np.Inf) + self.action_space = gym.spaces.Box( + np.ones(self.num_actions, dtype=np.float32) * -1.0, np.ones(self.num_actions, dtype=np.float32) * 1.0 + ) + self.observation_space = gym.spaces.Box( + np.ones(self.num_observations, dtype=np.float32) * -np.Inf, + np.ones(self.num_observations, dtype=np.float32) * np.Inf, + ) + self.state_space = gym.spaces.Box( + np.ones(self.num_states, dtype=np.float32) * -np.Inf, np.ones(self.num_states, dtype=np.float32) * np.Inf + ) def reset(self): observations = {"obs": torch.ones((self.num_envs, self.num_observations), device=self.device)} @@ -38,7 +45,9 @@ def reset(self): def step(self, actions): assert actions.clone().shape == torch.Size([self.num_envs, 1]) - observations = {"obs": torch.ones((self.num_envs, self.num_observations), device=self.device, dtype=torch.float32)} + observations = { + "obs": torch.ones((self.num_envs, self.num_observations), device=self.device, dtype=torch.float32) + } rewards = torch.zeros(self.num_envs, device=self.device, dtype=torch.float32) dones = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) return observations, rewards, dones, self.extras diff --git a/tests/jax/test_jax_wrapper_pettingzoo.py b/tests/jax/test_jax_wrapper_pettingzoo.py index d62756fe..93c77df3 100644 --- a/tests/jax/test_jax_wrapper_pettingzoo.py +++ b/tests/jax/test_jax_wrapper_pettingzoo.py @@ -21,7 +21,10 @@ def test_env(capsys: pytest.CaptureFixture, backend: str): num_envs = 1 num_agents = 20 possible_agents = [f"piston_{i}" for i in range(num_agents)] - action = {f"piston_{i}": jnp.ones((num_envs, 1)) if backend == "jax" else np.ones((num_envs, 1)) for i in range(num_agents)} + action = { + f"piston_{i}": jnp.ones((num_envs, 1)) if backend == "jax" else np.ones((num_envs, 1)) + for i in range(num_agents) + } # load wrap the environment original_env = pistonball_v6.parallel_env(n_pistons=num_agents, continuous=True, max_cycles=125) @@ -38,7 +41,11 @@ def test_env(capsys: pytest.CaptureFixture, backend: str): assert isinstance(env.action_spaces, Mapping) and 
len(env.action_spaces) == num_agents for agent in possible_agents: assert isinstance(env.state_space(agent), gym.Space) and env.state_space(agent).shape == (560, 880, 3) - assert isinstance(env.observation_space(agent), gym.Space) and env.observation_space(agent).shape == (457, 120, 3) + assert isinstance(env.observation_space(agent), gym.Space) and env.observation_space(agent).shape == ( + 457, + 120, + 3, + ) assert isinstance(env.action_space(agent), gym.Space) and env.action_space(agent).shape == (1,) assert isinstance(env.possible_agents, list) and sorted(env.possible_agents) == sorted(possible_agents) assert isinstance(env.num_envs, int) and env.num_envs == num_envs @@ -54,7 +61,10 @@ def test_env(capsys: pytest.CaptureFixture, backend: str): assert isinstance(observation, Mapping) assert isinstance(info, Mapping) for agent in possible_agents: - assert isinstance(observation[agent], Array) and observation[agent].shape == (num_envs, math.prod((457, 120, 3))) + assert isinstance(observation[agent], Array) and observation[agent].shape == ( + num_envs, + math.prod((457, 120, 3)), + ) for _ in range(3): observation, reward, terminated, truncated, info = env.step(action) state = env.state() @@ -65,7 +75,10 @@ def test_env(capsys: pytest.CaptureFixture, backend: str): assert isinstance(truncated, Mapping) assert isinstance(info, Mapping) for agent in possible_agents: - assert isinstance(observation[agent], Array) and observation[agent].shape == (num_envs, math.prod((457, 120, 3))) + assert isinstance(observation[agent], Array) and observation[agent].shape == ( + num_envs, + math.prod((457, 120, 3)), + ) assert isinstance(reward[agent], Array) and reward[agent].shape == (num_envs, 1) assert isinstance(terminated[agent], Array) and terminated[agent].shape == (num_envs, 1) assert isinstance(truncated[agent], Array) and truncated[agent].shape == (num_envs, 1) diff --git a/tests/stategies.py b/tests/stategies.py index 409b9771..4840a541 100644 --- a/tests/stategies.py +++ b/tests/stategies.py @@ -28,11 +28,14 @@ def gymnasium_space_stategy(draw, space_type: str = "", remaining_iterations: in return gymnasium.spaces.Dict(spaces) elif space_type == "Tuple": remaining_iterations -= 1 - spaces = draw(st.lists(gymnasium_space_stategy(remaining_iterations=remaining_iterations), min_size=1, max_size=3)) + spaces = draw( + st.lists(gymnasium_space_stategy(remaining_iterations=remaining_iterations), min_size=1, max_size=3) + ) return gymnasium.spaces.Tuple(spaces) else: raise ValueError(f"Invalid space type: {space_type}") + @st.composite def gym_space_stategy(draw, space_type: str = "", remaining_iterations: int = 5) -> gym.spaces.Space: if not space_type: diff --git a/tests/test_agents.py b/tests/test_agents.py index 2fcb36d5..48a6d3dd 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -23,36 +23,39 @@ @pytest.fixture def classes_and_kwargs(): - return [(A2C, {"models": {"policy": DummyModel()}}), - (AMP, {"models": {"policy": DummyModel()}}), - (CEM, {"models": {"policy": DummyModel()}}), - (DDPG, {"models": {"policy": DummyModel()}}), - (DQN, {"models": {"policy": DummyModel()}}), - (DDQN, {"models": {"policy": DummyModel()}}), - (PPO, {"models": {"policy": DummyModel()}}), - (Q_LEARNING, {"models": {"policy": DummyModel()}}), - (SAC, {"models": {"policy": DummyModel()}}), - (SARSA, {"models": {"policy": DummyModel()}}), - (TD3, {"models": {"policy": DummyModel()}}), - (TRPO, {"models": {"policy": DummyModel()}})] + return [ + (A2C, {"models": {"policy": DummyModel()}}), + (AMP, 
{"models": {"policy": DummyModel()}}), + (CEM, {"models": {"policy": DummyModel()}}), + (DDPG, {"models": {"policy": DummyModel()}}), + (DQN, {"models": {"policy": DummyModel()}}), + (DDQN, {"models": {"policy": DummyModel()}}), + (PPO, {"models": {"policy": DummyModel()}}), + (Q_LEARNING, {"models": {"policy": DummyModel()}}), + (SAC, {"models": {"policy": DummyModel()}}), + (SARSA, {"models": {"policy": DummyModel()}}), + (TD3, {"models": {"policy": DummyModel()}}), + (TRPO, {"models": {"policy": DummyModel()}}), + ] def test_agent(capsys, classes_and_kwargs): for klass, kwargs in classes_and_kwargs: - cfg = {"learning_starts": 1, - "experiment": {"write_interval": 0}} + cfg = {"learning_starts": 1, "experiment": {"write_interval": 0}} agent: Agent = klass(cfg=cfg, **kwargs) agent.init() agent.pre_interaction(timestep=0, timesteps=1) # agent.act(None, timestep=0, timestesps=1) - agent.record_transition(states=torch.tensor([]), - actions=torch.tensor([]), - rewards=torch.tensor([]), - next_states=torch.tensor([]), - terminated=torch.tensor([]), - truncated=torch.tensor([]), - infos={}, - timestep=0, - timesteps=1) + agent.record_transition( + states=torch.tensor([]), + actions=torch.tensor([]), + rewards=torch.tensor([]), + next_states=torch.tensor([]), + terminated=torch.tensor([]), + truncated=torch.tensor([]), + infos={}, + timestep=0, + timesteps=1, + ) agent.post_interaction(timestep=0, timesteps=1) diff --git a/tests/test_envs.py b/tests/test_envs.py index deac4c4d..00f5480c 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -15,8 +15,19 @@ def classes_and_kwargs(): return [] -@pytest.mark.parametrize("wrapper", ["gym", "gymnasium", "dm", "robosuite", \ - "isaacgym-preview2", "isaacgym-preview3", "isaacgym-preview4", "omniverse-isaacgym"]) +@pytest.mark.parametrize( + "wrapper", + [ + "gym", + "gymnasium", + "dm", + "robosuite", + "isaacgym-preview2", + "isaacgym-preview3", + "isaacgym-preview4", + "omniverse-isaacgym", + ], +) def test_wrap_env(capsys, classes_and_kwargs, wrapper): env = DummyEnv(num_envs=1) diff --git a/tests/test_examples_deepmind.py b/tests/test_examples_deepmind.py index 98222a3b..6cdf2a92 100644 --- a/tests/test_examples_deepmind.py +++ b/tests/test_examples_deepmind.py @@ -8,9 +8,10 @@ EXAMPLE_DIR = "deepmind" -SCRIPTS = ["dm_suite_cartpole_swingup_ddpg.py", - "dm_manipulation_stack_sac.py", ""] -EXAMPLES_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples")) +SCRIPTS = ["dm_suite_cartpole_swingup_ddpg.py", "dm_manipulation_stack_sac.py", ""] +EXAMPLES_DIR = os.path.abspath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples") +) COMMANDS = [f"python {os.path.join(EXAMPLES_DIR, EXAMPLE_DIR, script)}" for script in SCRIPTS] diff --git a/tests/test_examples_gym.py b/tests/test_examples_gym.py index 479fe841..73ff3a19 100644 --- a/tests/test_examples_gym.py +++ b/tests/test_examples_gym.py @@ -8,11 +8,10 @@ EXAMPLE_DIR = "gym" -SCRIPTS = ["ddpg_gym_pendulum.py", - "cem_gym_cartpole.py", - "dqn_gym_cartpole.py", - "q_learning_gym_frozen_lake.py"] -EXAMPLES_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples")) +SCRIPTS = ["ddpg_gym_pendulum.py", "cem_gym_cartpole.py", "dqn_gym_cartpole.py", "q_learning_gym_frozen_lake.py"] +EXAMPLES_DIR = os.path.abspath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples") +) COMMANDS = [f"python 
{os.path.join(EXAMPLES_DIR, EXAMPLE_DIR, script)}" for script in SCRIPTS] diff --git a/tests/test_examples_gymnasium.py b/tests/test_examples_gymnasium.py index a643a5a6..96b0a5a3 100644 --- a/tests/test_examples_gymnasium.py +++ b/tests/test_examples_gymnasium.py @@ -8,11 +8,15 @@ EXAMPLE_DIR = "gymnasium" -SCRIPTS = ["ddpg_gymnasium_pendulum.py", - "cem_gymnasium_cartpole.py", - "dqn_gymnasium_cartpole.py", - "q_learning_gymnasium_frozen_lake.py"] -EXAMPLES_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples")) +SCRIPTS = [ + "ddpg_gymnasium_pendulum.py", + "cem_gymnasium_cartpole.py", + "dqn_gymnasium_cartpole.py", + "q_learning_gymnasium_frozen_lake.py", +] +EXAMPLES_DIR = os.path.abspath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples") +) COMMANDS = [f"python {os.path.join(EXAMPLES_DIR, EXAMPLE_DIR, script)}" for script in SCRIPTS] diff --git a/tests/test_examples_isaac_orbit.py b/tests/test_examples_isaac_orbit.py index a9f17bee..8e89dfea 100644 --- a/tests/test_examples_isaac_orbit.py +++ b/tests/test_examples_isaac_orbit.py @@ -13,8 +13,13 @@ EXAMPLE_DIR = "isaacorbit" SCRIPTS = ["ppo_cartpole.py"] -EXAMPLES_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples")) -COMMANDS = [f"{PYTHON_ENVIRONMENT} {os.path.join(EXAMPLES_DIR, EXAMPLE_DIR, script)} --headless --num_envs 64" for script in SCRIPTS] +EXAMPLES_DIR = os.path.abspath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples") +) +COMMANDS = [ + f"{PYTHON_ENVIRONMENT} {os.path.join(EXAMPLES_DIR, EXAMPLE_DIR, script)} --headless --num_envs 64" + for script in SCRIPTS +] @pytest.mark.parametrize("command", COMMANDS) diff --git a/tests/test_examples_isaacgym.py b/tests/test_examples_isaacgym.py index 523a86db..2ca0f1e1 100644 --- a/tests/test_examples_isaacgym.py +++ b/tests/test_examples_isaacgym.py @@ -8,9 +8,10 @@ EXAMPLE_DIR = "isaacgym" -SCRIPTS = ["ppo_cartpole.py", - "trpo_cartpole.py"] -EXAMPLES_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples")) +SCRIPTS = ["ppo_cartpole.py", "trpo_cartpole.py"] +EXAMPLES_DIR = os.path.abspath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples") +) COMMANDS = [f"python {os.path.join(EXAMPLES_DIR, EXAMPLE_DIR, script)} headless=True num_envs=64" for script in SCRIPTS] diff --git a/tests/test_examples_isaacsim.py b/tests/test_examples_isaacsim.py index ef367110..5b1fdd26 100644 --- a/tests/test_examples_isaacsim.py +++ b/tests/test_examples_isaacsim.py @@ -13,7 +13,9 @@ EXAMPLE_DIR = "isaacsim" SCRIPTS = ["cartpole_example_skrl.py"] -EXAMPLES_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples")) +EXAMPLES_DIR = os.path.abspath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples") +) COMMANDS = [f"{PYTHON_ENVIRONMENT} {os.path.join(EXAMPLES_DIR, EXAMPLE_DIR, script)}" for script in SCRIPTS] diff --git a/tests/test_examples_omniisaacgym.py b/tests/test_examples_omniisaacgym.py index d92f7c3c..d0ccddf5 100644 --- a/tests/test_examples_omniisaacgym.py +++ b/tests/test_examples_omniisaacgym.py @@ -13,8 +13,13 @@ EXAMPLE_DIR = "omniisaacgym" SCRIPTS = ["ppo_cartpole.py"] -EXAMPLES_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", 
"examples")) -COMMANDS = [f"{PYTHON_ENVIRONMENT} {os.path.join(EXAMPLES_DIR, EXAMPLE_DIR, script)} headless=True num_envs=64" for script in SCRIPTS] +EXAMPLES_DIR = os.path.abspath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples") +) +COMMANDS = [ + f"{PYTHON_ENVIRONMENT} {os.path.join(EXAMPLES_DIR, EXAMPLE_DIR, script)} headless=True num_envs=64" + for script in SCRIPTS +] @pytest.mark.parametrize("command", COMMANDS) diff --git a/tests/test_examples_robosuite.py b/tests/test_examples_robosuite.py index a03b2f61..c238538d 100644 --- a/tests/test_examples_robosuite.py +++ b/tests/test_examples_robosuite.py @@ -9,7 +9,9 @@ EXAMPLE_DIR = "robosuite" SCRIPTS = [] -EXAMPLES_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples")) +EXAMPLES_DIR = os.path.abspath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples") +) COMMANDS = [f"python {os.path.join(EXAMPLES_DIR, EXAMPLE_DIR, script)}" for script in SCRIPTS] diff --git a/tests/test_examples_shimmy.py b/tests/test_examples_shimmy.py index 3283d239..623ad53d 100644 --- a/tests/test_examples_shimmy.py +++ b/tests/test_examples_shimmy.py @@ -8,10 +8,14 @@ EXAMPLE_DIR = "shimmy" -SCRIPTS = ["dqn_shimmy_atari_pong.py", - "sac_shimmy_dm_control_acrobot_swingup_sparse.py", - "ddpg_openai_gym_compatibility_pendulum.py"] -EXAMPLES_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples")) +SCRIPTS = [ + "dqn_shimmy_atari_pong.py", + "sac_shimmy_dm_control_acrobot_swingup_sparse.py", + "ddpg_openai_gym_compatibility_pendulum.py", +] +EXAMPLES_DIR = os.path.abspath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "docs", "source", "examples") +) COMMANDS = [f"python {os.path.join(EXAMPLES_DIR, EXAMPLE_DIR, script)}" for script in SCRIPTS] diff --git a/tests/test_jax_memories_memory.py b/tests/test_jax_memories_memory.py index 586205b6..1f0d996f 100644 --- a/tests/test_jax_memories_memory.py +++ b/tests/test_jax_memories_memory.py @@ -46,13 +46,25 @@ def test_tensor_names(self): # test memory.get_tensor_by_name for name, size, dtype in zip(self.names, self.sizes, self.dtypes): tensor = memory.get_tensor_by_name(name, keepdim=True) - self.assertSequenceEqual(memory.get_tensor_by_name(name, keepdim=True).shape, (memory_size, num_envs, size), "get_tensor_by_name(..., keepdim=True)") - self.assertSequenceEqual(memory.get_tensor_by_name(name, keepdim=False).shape, (memory_size * num_envs, size), "get_tensor_by_name(..., keepdim=False)") - self.assertEqual(memory.get_tensor_by_name(name, keepdim=True).dtype, dtype, "get_tensor_by_name(...).dtype") + self.assertSequenceEqual( + memory.get_tensor_by_name(name, keepdim=True).shape, + (memory_size, num_envs, size), + "get_tensor_by_name(..., keepdim=True)", + ) + self.assertSequenceEqual( + memory.get_tensor_by_name(name, keepdim=False).shape, + (memory_size * num_envs, size), + "get_tensor_by_name(..., keepdim=False)", + ) + self.assertEqual( + memory.get_tensor_by_name(name, keepdim=True).dtype, dtype, "get_tensor_by_name(...).dtype" + ) # test memory.set_tensor_by_name for name, size, dtype in zip(self.names, self.sizes, self.raw_dtypes): - new_tensor = jnp.arange(memory_size * num_envs * size).reshape(memory_size, num_envs, size).astype(dtype) + new_tensor = ( + jnp.arange(memory_size * num_envs * size).reshape(memory_size, num_envs, size).astype(dtype) + ) memory.set_tensor_by_name(name, 
new_tensor) tensor = memory.get_tensor_by_name(name, keepdim=True) self.assertTrue((tensor == new_tensor).all().item(), "set_tensor_by_name(...)") @@ -68,30 +80,39 @@ def test_sample(self): # fill memory for name, size, dtype in zip(self.names, self.sizes, self.raw_dtypes): - new_tensor = jnp.arange(memory_size * num_envs * size).reshape(memory_size, num_envs, size).astype(dtype) + new_tensor = ( + jnp.arange(memory_size * num_envs * size).reshape(memory_size, num_envs, size).astype(dtype) + ) memory.set_tensor_by_name(name, new_tensor) # test memory.sample_all for i, mini_batches in enumerate(self.mini_batches): samples = memory.sample_all(self.names, mini_batches=mini_batches) for sample, name, size in zip(samples[i], self.names, self.sizes): - self.assertSequenceEqual(sample.shape, (memory_size * num_envs, size), f"sample_all(...).shape with mini_batches={mini_batches}") + self.assertSequenceEqual( + sample.shape, + (memory_size * num_envs, size), + f"sample_all(...).shape with mini_batches={mini_batches}", + ) tensor = memory.get_tensor_by_name(name, keepdim=True) - self.assertTrue((sample.reshape(memory_size, num_envs, size) == tensor).all().item(), f"sample_all(...) with mini_batches={mini_batches}") + self.assertTrue( + (sample.reshape(memory_size, num_envs, size) == tensor).all().item(), + f"sample_all(...) with mini_batches={mini_batches}", + ) -if __name__ == '__main__': +if __name__ == "__main__": import sys - if not sys.argv[-1] == '--debug': - raise RuntimeError('Test can only be runned manually with --debug flag') + if not sys.argv[-1] == "--debug": + raise RuntimeError("Test can only be runned manually with --debug flag") test = TestCase() test.setUp() for method in dir(test): - if method.startswith('test_'): - print('Running test: {}'.format(method)) + if method.startswith("test_"): + print("Running test: {}".format(method)) getattr(test, method)() test.tearDown() - print('All tests passed.') + print("All tests passed.") diff --git a/tests/test_memories.py b/tests/test_memories.py index 1cf2da2c..746daea5 100644 --- a/tests/test_memories.py +++ b/tests/test_memories.py @@ -17,7 +17,9 @@ def classes_and_kwargs(): @pytest.mark.parametrize("device", [None, "cpu", "cuda:0"]) def test_device(capsys, classes_and_kwargs, device): - _device = torch.device(device) if device is not None else torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + _device = ( + torch.device(device) if device is not None else torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + ) for klass, kwargs in classes_and_kwargs: try: @@ -30,7 +32,12 @@ def test_device(capsys, classes_and_kwargs, device): assert memory.device == _device # defined device -@hypothesis.given(names=st.sets(st.text(alphabet=string.ascii_letters + string.digits + "_", min_size=1, max_size=10), min_size=1, max_size=10)) + +@hypothesis.given( + names=st.sets( + st.text(alphabet=string.ascii_letters + string.digits + "_", min_size=1, max_size=10), min_size=1, max_size=10 + ) +) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) def test_create_tensors(capsys, classes_and_kwargs, names): for klass, kwargs in classes_and_kwargs: @@ -41,9 +48,12 @@ def test_create_tensors(capsys, classes_and_kwargs, names): assert memory.get_tensor_names() == sorted(names) -@hypothesis.given(memory_size=st.integers(min_value=1, max_value=100), - num_envs=st.integers(min_value=1, max_value=10), - num_samples=st.integers(min_value=1, max_value=500)) + +@hypothesis.given( + 
memory_size=st.integers(min_value=1, max_value=100), + num_envs=st.integers(min_value=1, max_value=10), + num_samples=st.integers(min_value=1, max_value=500), +) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) def test_add_samples(capsys, classes_and_kwargs, memory_size, num_envs, num_samples): for klass, kwargs in classes_and_kwargs: diff --git a/tests/test_model_instantiators.py b/tests/test_model_instantiators.py index 75ec2bf9..ab8756c7 100644 --- a/tests/test_model_instantiators.py +++ b/tests/test_model_instantiators.py @@ -11,16 +11,13 @@ categorical_model, deterministic_model, gaussian_model, - multivariate_gaussian_model + multivariate_gaussian_model, ) @pytest.fixture def classes_and_kwargs(): - return [(categorical_model, {}), - (deterministic_model, {}), - (gaussian_model, {}), - (multivariate_gaussian_model, {})] + return [(categorical_model, {}), (deterministic_model, {}), (gaussian_model, {}), (multivariate_gaussian_model, {})] def test_models(capsys, classes_and_kwargs): diff --git a/tests/test_resources_noises.py b/tests/test_resources_noises.py index 108da62e..927eb8b8 100644 --- a/tests/test_resources_noises.py +++ b/tests/test_resources_noises.py @@ -10,13 +10,17 @@ @pytest.fixture def classes_and_kwargs(): - return [(GaussianNoise, {"mean": 0, "std": 1}), - (OrnsteinUhlenbeckNoise, {"theta": 0.1, "sigma": 0.2, "base_scale": 0.3})] + return [ + (GaussianNoise, {"mean": 0, "std": 1}), + (OrnsteinUhlenbeckNoise, {"theta": 0.1, "sigma": 0.2, "base_scale": 0.3}), + ] @pytest.mark.parametrize("device", [None, "cpu", "cuda:0"]) def test_device(capsys, classes_and_kwargs, device): - _device = torch.device(device) if device is not None else torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + _device = ( + torch.device(device) if device is not None else torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + ) for klass, kwargs in classes_and_kwargs: try: @@ -31,6 +35,7 @@ def test_device(capsys, classes_and_kwargs, device): assert noise.device == _device # defined device assert output.device == _device # runtime device + @hypothesis.given(size=st.lists(st.integers(min_value=1, max_value=10), max_size=5)) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) def test_sample(capsys, classes_and_kwargs, size): diff --git a/tests/test_resources_preprocessors.py b/tests/test_resources_preprocessors.py index 66f817cb..f340c198 100644 --- a/tests/test_resources_preprocessors.py +++ b/tests/test_resources_preprocessors.py @@ -18,7 +18,9 @@ def classes_and_kwargs(): @pytest.mark.parametrize("device", [None, "cpu", "cuda:0"]) def test_device(capsys, classes_and_kwargs, device): - _device = torch.device(device) if device is not None else torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + _device = ( + torch.device(device) if device is not None else torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + ) for klass, kwargs in classes_and_kwargs: try: @@ -32,10 +34,16 @@ def test_device(capsys, classes_and_kwargs, device): assert preprocessor.device == _device # defined device assert preprocessor(torch.ones(kwargs["size"], device=_device)).device == _device # runtime device -@pytest.mark.parametrize("space_and_size", [(gym.spaces.Box(low=-1, high=1, shape=(2, 3)), 6), - (gymnasium.spaces.Box(low=-1, high=1, shape=(2, 3)), 6), - (gym.spaces.Discrete(n=3), 1), - (gymnasium.spaces.Discrete(n=3), 1)]) + +@pytest.mark.parametrize( + 
"space_and_size", + [ + (gym.spaces.Box(low=-1, high=1, shape=(2, 3)), 6), + (gymnasium.spaces.Box(low=-1, high=1, shape=(2, 3)), 6), + (gym.spaces.Discrete(n=3), 1), + (gymnasium.spaces.Discrete(n=3), 1), + ], +) def test_forward(capsys, classes_and_kwargs, space_and_size): for klass, kwargs in classes_and_kwargs: space, size = space_and_size diff --git a/tests/test_resources_schedulers.py b/tests/test_resources_schedulers.py index 79b1db32..692710e4 100644 --- a/tests/test_resources_schedulers.py +++ b/tests/test_resources_schedulers.py @@ -13,8 +13,9 @@ def classes_and_kwargs(): return [(KLAdaptiveRL, {})] -@pytest.mark.parametrize("optimizer", [torch.optim.Adam([torch.ones((1,))], lr=0.1), - torch.optim.SGD([torch.ones((1,))], lr=0.1)]) +@pytest.mark.parametrize( + "optimizer", [torch.optim.Adam([torch.ones((1,))], lr=0.1), torch.optim.SGD([torch.ones((1,))], lr=0.1)] +) def test_step(capsys, classes_and_kwargs, optimizer): for klass, kwargs in classes_and_kwargs: scheduler = klass(optimizer, **kwargs) diff --git a/tests/test_trainers.py b/tests/test_trainers.py index f89c5a8b..f4e98265 100644 --- a/tests/test_trainers.py +++ b/tests/test_trainers.py @@ -12,9 +12,11 @@ @pytest.fixture def classes_and_kwargs(): - return [(ManualTrainer, {"cfg": {"timesteps": 100}}), - (ParallelTrainer, {"cfg": {"timesteps": 100}}), - (SequentialTrainer, {"cfg": {"timesteps": 100}})] + return [ + (ManualTrainer, {"cfg": {"timesteps": 100}}), + (ParallelTrainer, {"cfg": {"timesteps": 100}}), + (SequentialTrainer, {"cfg": {"timesteps": 100}}), + ] def test_train(capsys, classes_and_kwargs): @@ -26,6 +28,7 @@ def test_train(capsys, classes_and_kwargs): trainer.train() + def test_eval(capsys, classes_and_kwargs): env = DummyEnv(num_envs=1) agent = DummyAgent() diff --git a/tests/torch/test_torch_model_instantiators.py b/tests/torch/test_torch_model_instantiators.py index 44ff6c22..e84e5ec1 100644 --- a/tests/torch/test_torch_model_instantiators.py +++ b/tests/torch/test_torch_model_instantiators.py @@ -13,151 +13,175 @@ deterministic_model, gaussian_model, multivariate_gaussian_model, - shared_model + shared_model, ) -@hypothesis.given(observation_space_size=st.integers(min_value=1, max_value=10), - action_space_size=st.integers(min_value=1, max_value=10)) +@hypothesis.given( + observation_space_size=st.integers(min_value=1, max_value=10), + action_space_size=st.integers(min_value=1, max_value=10), +) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) @pytest.mark.parametrize("device", [None, "cpu", "cuda:0"]) def test_categorical_model(capsys, observation_space_size, action_space_size, device): observation_space = gym.spaces.Box(np.array([-1] * observation_space_size), np.array([1] * observation_space_size)) action_space = gym.spaces.Discrete(action_space_size) # TODO: randomize all parameters - model = categorical_model(observation_space=observation_space, - action_space=action_space, - device=device, - unnormalized_log_prob=True, - input_shape=Shape.STATES, - hiddens=[256, 256], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None) + model = categorical_model( + observation_space=observation_space, + action_space=action_space, + device=device, + unnormalized_log_prob=True, + input_shape=Shape.STATES, + hiddens=[256, 256], + hidden_activation=["relu", "relu"], + output_shape=Shape.ACTIONS, + output_activation=None, + ) model.to(device=device) observations = torch.ones((10, model.num_observations), 
device=device) output = model.act({"states": observations}) assert output[0].shape == (10, 1) -@hypothesis.given(observation_space_size=st.integers(min_value=1, max_value=10), - action_space_size=st.integers(min_value=1, max_value=10)) + +@hypothesis.given( + observation_space_size=st.integers(min_value=1, max_value=10), + action_space_size=st.integers(min_value=1, max_value=10), +) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) @pytest.mark.parametrize("device", [None, "cpu", "cuda:0"]) def test_deterministic_model(capsys, observation_space_size, action_space_size, device): observation_space = gym.spaces.Box(np.array([-1] * observation_space_size), np.array([1] * observation_space_size)) action_space = gym.spaces.Box(np.array([-1] * action_space_size), np.array([1] * action_space_size)) # TODO: randomize all parameters - model = deterministic_model(observation_space=observation_space, - action_space=action_space, - device=device, - clip_actions=False, - input_shape=Shape.STATES, - hiddens=[256, 256], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1) + model = deterministic_model( + observation_space=observation_space, + action_space=action_space, + device=device, + clip_actions=False, + input_shape=Shape.STATES, + hiddens=[256, 256], + hidden_activation=["relu", "relu"], + output_shape=Shape.ACTIONS, + output_activation=None, + output_scale=1, + ) model.to(device=device) observations = torch.ones((10, model.num_observations), device=device) output = model.act({"states": observations}) assert output[0].shape == (10, model.num_actions) -@hypothesis.given(observation_space_size=st.integers(min_value=1, max_value=10), - action_space_size=st.integers(min_value=1, max_value=10)) + +@hypothesis.given( + observation_space_size=st.integers(min_value=1, max_value=10), + action_space_size=st.integers(min_value=1, max_value=10), +) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) @pytest.mark.parametrize("device", [None, "cpu", "cuda:0"]) def test_gaussian_model(capsys, observation_space_size, action_space_size, device): observation_space = gym.spaces.Box(np.array([-1] * observation_space_size), np.array([1] * observation_space_size)) action_space = gym.spaces.Box(np.array([-1] * action_space_size), np.array([1] * action_space_size)) # TODO: randomize all parameters - model = gaussian_model(observation_space=observation_space, - action_space=action_space, - device=device, - clip_actions=False, - clip_log_std=True, - min_log_std=-20, - max_log_std=2, - initial_log_std=0, - input_shape=Shape.STATES, - hiddens=[256, 256], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1) + model = gaussian_model( + observation_space=observation_space, + action_space=action_space, + device=device, + clip_actions=False, + clip_log_std=True, + min_log_std=-20, + max_log_std=2, + initial_log_std=0, + input_shape=Shape.STATES, + hiddens=[256, 256], + hidden_activation=["relu", "relu"], + output_shape=Shape.ACTIONS, + output_activation=None, + output_scale=1, + ) model.to(device=device) observations = torch.ones((10, model.num_observations), device=device) output = model.act({"states": observations}) assert output[0].shape == (10, model.num_actions) -@hypothesis.given(observation_space_size=st.integers(min_value=1, max_value=10), - action_space_size=st.integers(min_value=1, 
max_value=10)) + +@hypothesis.given( + observation_space_size=st.integers(min_value=1, max_value=10), + action_space_size=st.integers(min_value=1, max_value=10), +) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) @pytest.mark.parametrize("device", [None, "cpu", "cuda:0"]) def test_multivariate_gaussian_model(capsys, observation_space_size, action_space_size, device): observation_space = gym.spaces.Box(np.array([-1] * observation_space_size), np.array([1] * observation_space_size)) action_space = gym.spaces.Box(np.array([-1] * action_space_size), np.array([1] * action_space_size)) # TODO: randomize all parameters - model = multivariate_gaussian_model(observation_space=observation_space, - action_space=action_space, - device=device, - clip_actions=False, - clip_log_std=True, - min_log_std=-20, - max_log_std=2, - initial_log_std=0, - input_shape=Shape.STATES, - hiddens=[256, 256], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1) + model = multivariate_gaussian_model( + observation_space=observation_space, + action_space=action_space, + device=device, + clip_actions=False, + clip_log_std=True, + min_log_std=-20, + max_log_std=2, + initial_log_std=0, + input_shape=Shape.STATES, + hiddens=[256, 256], + hidden_activation=["relu", "relu"], + output_shape=Shape.ACTIONS, + output_activation=None, + output_scale=1, + ) model.to(device=device) observations = torch.ones((10, model.num_observations), device=device) output = model.act({"states": observations}) assert output[0].shape == (10, model.num_actions) -@hypothesis.given(observation_space_size=st.integers(min_value=1, max_value=10), - action_space_size=st.integers(min_value=1, max_value=10)) + +@hypothesis.given( + observation_space_size=st.integers(min_value=1, max_value=10), + action_space_size=st.integers(min_value=1, max_value=10), +) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) @pytest.mark.parametrize("device", [None, "cpu", "cuda:0"]) def test_shared_model(capsys, observation_space_size, action_space_size, device): observation_space = gym.spaces.Box(np.array([-1] * observation_space_size), np.array([1] * observation_space_size)) action_space = gym.spaces.Box(np.array([-1] * action_space_size), np.array([1] * action_space_size)) # TODO: randomize all parameters - model = shared_model(observation_space=observation_space, - action_space=action_space, - device=device, - structure="", - roles=["policy", "value"], - parameters=[ - { - "clip_actions": False, - "clip_log_std": True, - "min_log_std": -20, - "max_log_std": 2, - "initial_log_std": 0, - "input_shape": Shape.STATES, - "hiddens": [256, 256], - "hidden_activation": ["relu", "relu"], - "output_shape": Shape.ACTIONS, - "output_activation": None, - "output_scale": 1, - }, - { - "clip_actions": False, - "input_shape": Shape.STATES, - "hiddens": [256, 256], - "hidden_activation": ["relu", "relu"], - "output_shape": Shape.ONE, - "output_activation": None, - "output_scale": 1, - } - ], - single_forward_pass=True) + model = shared_model( + observation_space=observation_space, + action_space=action_space, + device=device, + structure="", + roles=["policy", "value"], + parameters=[ + { + "clip_actions": False, + "clip_log_std": True, + "min_log_std": -20, + "max_log_std": 2, + "initial_log_std": 0, + "input_shape": Shape.STATES, + "hiddens": [256, 256], + "hidden_activation": ["relu", "relu"], + "output_shape": 
Shape.ACTIONS, + "output_activation": None, + "output_scale": 1, + }, + { + "clip_actions": False, + "input_shape": Shape.STATES, + "hiddens": [256, 256], + "hidden_activation": ["relu", "relu"], + "output_shape": Shape.ONE, + "output_activation": None, + "output_scale": 1, + }, + ], + single_forward_pass=True, + ) model.to(device=device) observations = torch.ones((10, model.num_observations), device=device) diff --git a/tests/torch/test_torch_model_instantiators_definition.py b/tests/torch/test_torch_model_instantiators_definition.py index b4c9d52a..796e000d 100644 --- a/tests/torch/test_torch_model_instantiators_definition.py +++ b/tests/torch/test_torch_model_instantiators_definition.py @@ -14,7 +14,7 @@ deterministic_model, gaussian_model, multivariate_gaussian_model, - shared_model + shared_model, ) from skrl.utils.model_instantiators.torch.common import _generate_modules, _get_activation_function, _parse_input @@ -31,6 +31,7 @@ def test_get_activation_function(capsys): assert activation is not None, f"{item} -> None" exec(activation, _globals, {}) + def test_parse_input(capsys): # check for Shape enum (compatibility with prior versions) for input in [Shape.STATES, Shape.OBSERVATIONS, Shape.ACTIONS, Shape.STATES_ACTIONS]: @@ -52,6 +53,7 @@ def test_parse_input(capsys): output = _parse_input(str(input)) assert output.replace("'", '"') == statement, f"'{output}' != '{statement}'" + def test_generate_modules(capsys): _globals = {"nn": torch.nn} @@ -147,6 +149,7 @@ def test_generate_modules(capsys): assert isinstance(container, torch.nn.Sequential) assert len(container) == 2 + def test_gaussian_model(capsys): device = "cpu" observation_space = gym.spaces.Box(np.array([-1] * 5), np.array([1] * 5)) @@ -170,19 +173,15 @@ def test_gaussian_model(capsys): """ content = yaml.safe_load(content) # source - model = gaussian_model(observation_space=observation_space, - action_space=action_space, - device=device, - return_source=True, - **content) + model = gaussian_model( + observation_space=observation_space, action_space=action_space, device=device, return_source=True, **content + ) with capsys.disabled(): print(model) # instance - model = gaussian_model(observation_space=observation_space, - action_space=action_space, - device=device, - return_source=False, - **content) + model = gaussian_model( + observation_space=observation_space, action_space=action_space, device=device, return_source=False, **content + ) model.to(device=device) with capsys.disabled(): print(model) @@ -191,6 +190,7 @@ def test_gaussian_model(capsys): output = model.act({"states": observations}) assert output[0].shape == (10, 2) + def test_multivariate_gaussian_model(capsys): device = "cpu" observation_space = gym.spaces.Box(np.array([-1] * 5), np.array([1] * 5)) @@ -214,19 +214,15 @@ def test_multivariate_gaussian_model(capsys): """ content = yaml.safe_load(content) # source - model = multivariate_gaussian_model(observation_space=observation_space, - action_space=action_space, - device=device, - return_source=True, - **content) + model = multivariate_gaussian_model( + observation_space=observation_space, action_space=action_space, device=device, return_source=True, **content + ) with capsys.disabled(): print(model) # instance - model = multivariate_gaussian_model(observation_space=observation_space, - action_space=action_space, - device=device, - return_source=False, - **content) + model = multivariate_gaussian_model( + observation_space=observation_space, action_space=action_space, device=device, return_source=False, 
**content + ) model.to(device=device) with capsys.disabled(): print(model) @@ -235,6 +231,7 @@ def test_multivariate_gaussian_model(capsys): output = model.act({"states": observations}) assert output[0].shape == (10, 2) + def test_deterministic_model(capsys): device = "cpu" observation_space = gym.spaces.Box(np.array([-1] * 5), np.array([1] * 5)) @@ -255,19 +252,15 @@ def test_deterministic_model(capsys): """ content = yaml.safe_load(content) # source - model = deterministic_model(observation_space=observation_space, - action_space=action_space, - device=device, - return_source=True, - **content) + model = deterministic_model( + observation_space=observation_space, action_space=action_space, device=device, return_source=True, **content + ) with capsys.disabled(): print(model) # instance - model = deterministic_model(observation_space=observation_space, - action_space=action_space, - device=device, - return_source=False, - **content) + model = deterministic_model( + observation_space=observation_space, action_space=action_space, device=device, return_source=False, **content + ) model.to(device=device) with capsys.disabled(): print(model) @@ -276,6 +269,7 @@ def test_deterministic_model(capsys): output = model.act({"states": observations}) assert output[0].shape == (10, 3) + def test_categorical_model(capsys): device = "cpu" observation_space = gym.spaces.Box(np.array([-1] * 5), np.array([1] * 5)) @@ -295,19 +289,15 @@ def test_categorical_model(capsys): """ content = yaml.safe_load(content) # source - model = categorical_model(observation_space=observation_space, - action_space=action_space, - device=device, - return_source=True, - **content) + model = categorical_model( + observation_space=observation_space, action_space=action_space, device=device, return_source=True, **content + ) with capsys.disabled(): print(model) # instance - model = categorical_model(observation_space=observation_space, - action_space=action_space, - device=device, - return_source=False, - **content) + model = categorical_model( + observation_space=observation_space, action_space=action_space, device=device, return_source=False, **content + ) model.to(device=device) with capsys.disabled(): print(model) @@ -316,6 +306,7 @@ def test_categorical_model(capsys): output = model.act({"states": observations}) assert output[0].shape == (10, 1) + def test_shared_model(capsys): device = "cpu" observation_space = gym.spaces.Box(np.array([-1] * 5), np.array([1] * 5)) @@ -352,27 +343,31 @@ def test_shared_model(capsys): content_policy = yaml.safe_load(content_policy) content_value = yaml.safe_load(content_value) # source - model = shared_model(observation_space=observation_space, - action_space=action_space, - device=device, - roles=["policy", "value"], - parameters=[ - content_policy, - content_value, - ], - return_source=True) + model = shared_model( + observation_space=observation_space, + action_space=action_space, + device=device, + roles=["policy", "value"], + parameters=[ + content_policy, + content_value, + ], + return_source=True, + ) with capsys.disabled(): print(model) # instance - model = shared_model(observation_space=observation_space, - action_space=action_space, - device=device, - roles=["policy", "value"], - parameters=[ - content_policy, - content_value, - ], - return_source=False) + model = shared_model( + observation_space=observation_space, + action_space=action_space, + device=device, + roles=["policy", "value"], + parameters=[ + content_policy, + content_value, + ], + return_source=False, + ) 
model.to(device=device) with capsys.disabled(): print(model) diff --git a/tests/torch/test_torch_utils_spaces.py b/tests/torch/test_torch_utils_spaces.py index aa3a08ac..bc9c8126 100644 --- a/tests/torch/test_torch_utils_spaces.py +++ b/tests/torch/test_torch_utils_spaces.py @@ -14,7 +14,7 @@ sample_space, tensorize_space, unflatten_tensorized_space, - untensorize_space + untensorize_space, ) from ..stategies import gym_space_stategy, gymnasium_space_stategy @@ -28,6 +28,7 @@ def _check_backend(x, backend): else: raise ValueError(f"Invalid backend type: {backend}") + def check_sampled_space(space, x, n, backend): if isinstance(space, gymnasium.spaces.Box): _check_backend(x, backend) @@ -65,6 +66,7 @@ def occupied_size(s): space_size = compute_space_size(space, occupied_size=True) assert space_size == occupied_size(space) + @hypothesis.given(space=gymnasium_space_stategy()) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) def test_tensorize_space(capsys, space: gymnasium.spaces.Space): @@ -96,6 +98,7 @@ def check_tensorized_space(s, x, n): tensorized_space = tensorize_space(space, sampled_space) check_tensorized_space(space, tensorized_space, 5) + @hypothesis.given(space=gymnasium_space_stategy()) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) def test_untensorize_space(capsys, space: gymnasium.spaces.Space): @@ -107,7 +110,9 @@ def check_untensorized_space(s, x, squeeze_batch_dimension): assert isinstance(x, (np.ndarray, int)) assert isinstance(x, int) if squeeze_batch_dimension else x.shape == (1, 1) elif isinstance(s, gymnasium.spaces.MultiDiscrete): - assert isinstance(x, np.ndarray) and x.shape == s.nvec.shape if squeeze_batch_dimension else (1, *s.nvec.shape) + assert ( + isinstance(x, np.ndarray) and x.shape == s.nvec.shape if squeeze_batch_dimension else (1, *s.nvec.shape) + ) elif isinstance(s, gymnasium.spaces.Dict): list(map(check_untensorized_space, s.values(), x.values(), [squeeze_batch_dimension] * len(s))) elif isinstance(s, gymnasium.spaces.Tuple): @@ -123,6 +128,7 @@ def check_untensorized_space(s, x, squeeze_batch_dimension): untensorized_space = untensorize_space(space, tensorized_space, squeeze_batch_dimension=True) check_untensorized_space(space, untensorized_space, squeeze_batch_dimension=True) + @hypothesis.given(space=gymnasium_space_stategy(), batch_size=st.integers(min_value=1, max_value=10)) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) def test_sample_space(capsys, space: gymnasium.spaces.Space, batch_size: int): @@ -133,6 +139,7 @@ def test_sample_space(capsys, space: gymnasium.spaces.Space, batch_size: int): sampled_space = sample_space(space, batch_size, backend="torch") check_sampled_space(space, sampled_space, batch_size, backend="torch") + @hypothesis.given(space=gymnasium_space_stategy()) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) def test_flatten_tensorized_space(capsys, space: gymnasium.spaces.Space): @@ -146,6 +153,7 @@ def test_flatten_tensorized_space(capsys, space: gymnasium.spaces.Space): flattened_space = flatten_tensorized_space(tensorized_space) assert flattened_space.shape == (5, space_size) + @hypothesis.given(space=gymnasium_space_stategy()) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) def test_unflatten_tensorized_space(capsys, space: 
gymnasium.spaces.Space): @@ -159,6 +167,7 @@ def test_unflatten_tensorized_space(capsys, space: gymnasium.spaces.Space): unflattened_space = unflatten_tensorized_space(space, flattened_space) check_sampled_space(space, unflattened_space, 5, backend="torch") + @hypothesis.given(space=gym_space_stategy()) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) def test_convert_gym_space(capsys, space: gym.spaces.Space): diff --git a/tests/torch/test_torch_wrapper_deepmind.py b/tests/torch/test_torch_wrapper_deepmind.py index 57c46ca4..4d880e52 100644 --- a/tests/torch/test_torch_wrapper_deepmind.py +++ b/tests/torch/test_torch_wrapper_deepmind.py @@ -28,7 +28,10 @@ def test_env(capsys: pytest.CaptureFixture): # check properties assert env.state_space is None - assert isinstance(env.observation_space, gym.Space) and sorted(list(env.observation_space.keys())) == ["orientation", "velocity"] + assert isinstance(env.observation_space, gym.Space) and sorted(list(env.observation_space.keys())) == [ + "orientation", + "velocity", + ] assert isinstance(env.action_space, gym.Space) and env.action_space.shape == (1,) assert isinstance(env.num_envs, int) and env.num_envs == num_envs assert isinstance(env.num_agents, int) and env.num_agents == 1 diff --git a/tests/torch/test_torch_wrapper_gym.py b/tests/torch/test_torch_wrapper_gym.py index cfee7672..440688b2 100644 --- a/tests/torch/test_torch_wrapper_gym.py +++ b/tests/torch/test_torch_wrapper_gym.py @@ -46,6 +46,7 @@ def test_env(capsys: pytest.CaptureFixture): env.close() + @pytest.mark.parametrize("vectorization_mode", ["async", "sync"]) def test_vectorized_env(capsys: pytest.CaptureFixture, vectorization_mode: str): num_envs = 10 diff --git a/tests/torch/test_torch_wrapper_gymnasium.py b/tests/torch/test_torch_wrapper_gymnasium.py index f603801e..80bbc154 100644 --- a/tests/torch/test_torch_wrapper_gymnasium.py +++ b/tests/torch/test_torch_wrapper_gymnasium.py @@ -45,6 +45,7 @@ def test_env(capsys: pytest.CaptureFixture): env.close() + @pytest.mark.parametrize("vectorization_mode", ["async", "sync"]) def test_vectorized_env(capsys: pytest.CaptureFixture, vectorization_mode: str): num_envs = 10 diff --git a/tests/torch/test_torch_wrapper_isaacgym.py b/tests/torch/test_torch_wrapper_isaacgym.py index e7d2b367..c8614270 100644 --- a/tests/torch/test_torch_wrapper_isaacgym.py +++ b/tests/torch/test_torch_wrapper_isaacgym.py @@ -27,7 +27,7 @@ def __init__(self, num_states) -> None: self.state_space = gym.spaces.Box(np.ones(self.num_states) * -np.Inf, np.ones(self.num_states) * np.Inf) self.observation_space = gym.spaces.Box(np.ones(self.num_obs) * -np.Inf, np.ones(self.num_obs) * np.Inf) - self.action_space = gym.spaces.Box(np.ones(self.num_actions) * -1., np.ones(self.num_actions) * 1.) 
+ self.action_space = gym.spaces.Box(np.ones(self.num_actions) * -1.0, np.ones(self.num_actions) * 1.0) def reset(self) -> Dict[str, torch.Tensor]: obs_dict = {} diff --git a/tests/torch/test_torch_wrapper_isaaclab.py b/tests/torch/test_torch_wrapper_isaaclab.py index 66fffcbe..dca49e39 100644 --- a/tests/torch/test_torch_wrapper_isaaclab.py +++ b/tests/torch/test_torch_wrapper_isaaclab.py @@ -23,6 +23,7 @@ Dict[AgentID, dict], ] + class IsaacLabEnv(gym.Env): def __init__(self, num_states) -> None: self.num_actions = 1 @@ -58,7 +59,9 @@ def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) def step(self, action: torch.Tensor) -> VecEnvStepReturn: assert action.clone().shape == torch.Size([self.num_envs, 1]) - observations = {"policy": torch.ones((self.num_envs, self.num_observations), device=self.device, dtype=torch.float32)} + observations = { + "policy": torch.ones((self.num_envs, self.num_observations), device=self.device, dtype=torch.float32) + } rewards = torch.zeros(self.num_envs, device=self.device, dtype=torch.float32) terminated = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) truncated = torch.zeros_like(terminated) @@ -99,9 +102,7 @@ def _configure_env_spaces(self): if not self.num_states: self.state_space = None if self.num_states < 0: - self.state_space = gym.spaces.Box( - low=-np.inf, high=np.inf, shape=(sum(self.num_observations.values()),) - ) + self.state_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(sum(self.num_observations.values()),)) else: self.state_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.num_states,)) @@ -109,16 +110,28 @@ def _configure_env_spaces(self): def unwrapped(self): return self - def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[Dict[AgentID, ObsType], dict]: - observations = {agent: torch.ones((self.num_envs, self.num_observations[agent]), device=self.device) for agent in self.possible_agents} + def reset( + self, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[Dict[AgentID, ObsType], dict]: + observations = { + agent: torch.ones((self.num_envs, self.num_observations[agent]), device=self.device) + for agent in self.possible_agents + } return observations, self.extras def step(self, action: Dict[AgentID, torch.Tensor]) -> EnvStepReturn: for agent in self.possible_agents: assert action[agent].clone().shape == torch.Size([self.num_envs, self.num_actions[agent]]) - observations = {agent: torch.ones((self.num_envs, self.num_observations[agent]), device=self.device) for agent in self.possible_agents} - rewards = {agent: torch.zeros(self.num_envs, device=self.device, dtype=torch.float32) for agent in self.possible_agents} - terminated = {agent: torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) for agent in self.possible_agents} + observations = { + agent: torch.ones((self.num_envs, self.num_observations[agent]), device=self.device) + for agent in self.possible_agents + } + rewards = { + agent: torch.zeros(self.num_envs, device=self.device, dtype=torch.float32) for agent in self.possible_agents + } + terminated = { + agent: torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) for agent in self.possible_agents + } truncated = {agent: torch.zeros_like(terminated[agent]) for agent in self.possible_agents} return observations, rewards, terminated, truncated, self.extras @@ -176,6 +189,7 @@ def test_env(capsys: pytest.CaptureFixture, num_states: int): env.close() + 
@pytest.mark.parametrize("num_states", [0, 5]) def test_multi_agent_env(capsys: pytest.CaptureFixture, num_states: int): num_envs = 10 @@ -212,7 +226,9 @@ def test_multi_agent_env(capsys: pytest.CaptureFixture, num_states: int): assert isinstance(observation, Mapping) assert isinstance(info, Mapping) for i, agent in enumerate(possible_agents): - assert isinstance(observation[agent], torch.Tensor) and observation[agent].shape == torch.Size([num_envs, i + 20]) + assert isinstance(observation[agent], torch.Tensor) and observation[agent].shape == torch.Size( + [num_envs, i + 20] + ) for _ in range(3): observation, reward, terminated, truncated, info = env.step(action) state = env.state() @@ -223,10 +239,16 @@ def test_multi_agent_env(capsys: pytest.CaptureFixture, num_states: int): assert isinstance(truncated, Mapping) assert isinstance(info, Mapping) for i, agent in enumerate(possible_agents): - assert isinstance(observation[agent], torch.Tensor) and observation[agent].shape == torch.Size([num_envs, i + 20]) + assert isinstance(observation[agent], torch.Tensor) and observation[agent].shape == torch.Size( + [num_envs, i + 20] + ) assert isinstance(reward[agent], torch.Tensor) and reward[agent].shape == torch.Size([num_envs, 1]) - assert isinstance(terminated[agent], torch.Tensor) and terminated[agent].shape == torch.Size([num_envs, 1]) - assert isinstance(truncated[agent], torch.Tensor) and truncated[agent].shape == torch.Size([num_envs, 1]) + assert isinstance(terminated[agent], torch.Tensor) and terminated[agent].shape == torch.Size( + [num_envs, 1] + ) + assert isinstance(truncated[agent], torch.Tensor) and truncated[agent].shape == torch.Size( + [num_envs, 1] + ) if num_states: assert isinstance(state, torch.Tensor) and state.shape == torch.Size([num_envs, num_states]) else: diff --git a/tests/torch/test_torch_wrapper_omniisaacgym.py b/tests/torch/test_torch_wrapper_omniisaacgym.py index 8d0e899e..5525e947 100644 --- a/tests/torch/test_torch_wrapper_omniisaacgym.py +++ b/tests/torch/test_torch_wrapper_omniisaacgym.py @@ -25,9 +25,16 @@ def __init__(self, num_states) -> None: self.device = "cpu" # initialize data spaces (defaults to gym.Box) - self.action_space = gym.spaces.Box(np.ones(self.num_actions, dtype=np.float32) * -1.0, np.ones(self.num_actions, dtype=np.float32) * 1.0) - self.observation_space = gym.spaces.Box(np.ones(self.num_observations, dtype=np.float32) * -np.Inf, np.ones(self.num_observations, dtype=np.float32) * np.Inf) - self.state_space = gym.spaces.Box(np.ones(self.num_states, dtype=np.float32) * -np.Inf, np.ones(self.num_states, dtype=np.float32) * np.Inf) + self.action_space = gym.spaces.Box( + np.ones(self.num_actions, dtype=np.float32) * -1.0, np.ones(self.num_actions, dtype=np.float32) * 1.0 + ) + self.observation_space = gym.spaces.Box( + np.ones(self.num_observations, dtype=np.float32) * -np.Inf, + np.ones(self.num_observations, dtype=np.float32) * np.Inf, + ) + self.state_space = gym.spaces.Box( + np.ones(self.num_states, dtype=np.float32) * -np.Inf, np.ones(self.num_states, dtype=np.float32) * np.Inf + ) def reset(self): observations = {"obs": torch.ones((self.num_envs, self.num_observations), device=self.device)} @@ -35,7 +42,9 @@ def reset(self): def step(self, actions): assert actions.clone().shape == torch.Size([self.num_envs, 1]) - observations = {"obs": torch.ones((self.num_envs, self.num_observations), device=self.device, dtype=torch.float32)} + observations = { + "obs": torch.ones((self.num_envs, self.num_observations), device=self.device, 
dtype=torch.float32) + } rewards = torch.zeros(self.num_envs, device=self.device, dtype=torch.float32) dones = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) return observations, rewards, dones, self.extras diff --git a/tests/torch/test_torch_wrapper_pettingzoo.py b/tests/torch/test_torch_wrapper_pettingzoo.py index d6d1d860..ebe0531e 100644 --- a/tests/torch/test_torch_wrapper_pettingzoo.py +++ b/tests/torch/test_torch_wrapper_pettingzoo.py @@ -31,7 +31,11 @@ def test_env(capsys: pytest.CaptureFixture): assert isinstance(env.action_spaces, Mapping) and len(env.action_spaces) == num_agents for agent in possible_agents: assert isinstance(env.state_space(agent), gym.Space) and env.state_space(agent).shape == (560, 880, 3) - assert isinstance(env.observation_space(agent), gym.Space) and env.observation_space(agent).shape == (457, 120, 3) + assert isinstance(env.observation_space(agent), gym.Space) and env.observation_space(agent).shape == ( + 457, + 120, + 3, + ) assert isinstance(env.action_space(agent), gym.Space) and env.action_space(agent).shape == (1,) assert isinstance(env.possible_agents, list) and sorted(env.possible_agents) == sorted(possible_agents) assert isinstance(env.num_envs, int) and env.num_envs == num_envs @@ -47,7 +51,9 @@ def test_env(capsys: pytest.CaptureFixture): assert isinstance(observation, Mapping) assert isinstance(info, Mapping) for agent in possible_agents: - assert isinstance(observation[agent], torch.Tensor) and observation[agent].shape == torch.Size([num_envs, math.prod((457, 120, 3))]) + assert isinstance(observation[agent], torch.Tensor) and observation[agent].shape == torch.Size( + [num_envs, math.prod((457, 120, 3))] + ) for _ in range(3): observation, reward, terminated, truncated, info = env.step(action) state = env.state() @@ -58,10 +64,16 @@ def test_env(capsys: pytest.CaptureFixture): assert isinstance(truncated, Mapping) assert isinstance(info, Mapping) for agent in possible_agents: - assert isinstance(observation[agent], torch.Tensor) and observation[agent].shape == torch.Size([num_envs, math.prod((457, 120, 3))]) + assert isinstance(observation[agent], torch.Tensor) and observation[agent].shape == torch.Size( + [num_envs, math.prod((457, 120, 3))] + ) assert isinstance(reward[agent], torch.Tensor) and reward[agent].shape == torch.Size([num_envs, 1]) - assert isinstance(terminated[agent], torch.Tensor) and terminated[agent].shape == torch.Size([num_envs, 1]) - assert isinstance(truncated[agent], torch.Tensor) and truncated[agent].shape == torch.Size([num_envs, 1]) + assert isinstance(terminated[agent], torch.Tensor) and terminated[agent].shape == torch.Size( + [num_envs, 1] + ) + assert isinstance(truncated[agent], torch.Tensor) and truncated[agent].shape == torch.Size( + [num_envs, 1] + ) assert isinstance(state, torch.Tensor) and state.shape == torch.Size([num_envs, math.prod((560, 880, 3))]) env.close() diff --git a/tests/utils.py b/tests/utils.py index e6df75f0..9c3b6ef2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,7 +5,7 @@ class DummyEnv(gym.Env): - def __init__(self, num_envs, device = "cpu"): + def __init__(self, num_envs, device="cpu"): self.num_agents = 1 self.num_envs = num_envs self.device = torch.device(device) @@ -46,7 +46,9 @@ class _DummyBaseAgent: def __init__(self): pass - def record_transition(self, states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps): + def record_transition( + self, states, actions, rewards, next_states, terminated, truncated, infos, timestep, 
timesteps + ): pass def pre_interaction(self, timestep, timesteps): @@ -69,7 +71,9 @@ def init(self, trainer_cfg=None): def act(self, states, timestep, timesteps): return torch.tensor([]), None, {} - def record_transition(self, states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps): + def record_transition( + self, states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps + ): pass def pre_interaction(self, timestep, timesteps): From aaee05c39c2cda4158a4d3fde03e34894bf9c18c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Mon, 4 Nov 2024 17:09:46 -0500 Subject: [PATCH 7/8] Apply codespell --- .pre-commit-config.yaml | 2 +- docs/source/api/agents/sarsa.rst | 2 +- .../franka_emika_panda/reaching_franka_real_env.py | 2 +- .../real_world/kuka_lbr_iiwa/reaching_iiwa_real_env.py | 2 +- .../real_world/kuka_lbr_iiwa/reaching_iiwa_real_ros2_env.py | 6 +++--- .../real_world/kuka_lbr_iiwa/reaching_iiwa_real_ros_env.py | 6 +++--- docs/source/examples/utils/tensorboard_file_iterator.py | 2 +- skrl/__init__.py | 2 +- skrl/agents/jax/a2c/a2c.py | 2 +- skrl/agents/jax/ppo/ppo.py | 2 +- skrl/agents/jax/rpo/rpo.py | 2 +- skrl/agents/torch/a2c/a2c.py | 2 +- skrl/agents/torch/a2c/a2c_rnn.py | 2 +- skrl/agents/torch/amp/amp.py | 2 +- skrl/agents/torch/ppo/ppo.py | 2 +- skrl/agents/torch/ppo/ppo_rnn.py | 2 +- skrl/agents/torch/rpo/rpo.py | 2 +- skrl/agents/torch/rpo/rpo_rnn.py | 2 +- skrl/agents/torch/sarsa/sarsa.py | 2 +- skrl/agents/torch/trpo/trpo.py | 2 +- skrl/agents/torch/trpo/trpo_rnn.py | 2 +- skrl/envs/wrappers/jax/bidexhands_envs.py | 4 ++-- skrl/multi_agents/jax/ippo/ippo.py | 2 +- skrl/multi_agents/jax/mappo/mappo.py | 2 +- skrl/multi_agents/torch/ippo/ippo.py | 2 +- skrl/multi_agents/torch/mappo/mappo.py | 2 +- skrl/resources/preprocessors/jax/running_standard_scaler.py | 4 ++-- tests/jax/test_jax_utils_spaces.py | 2 +- tests/{stategies.py => strategies.py} | 0 tests/test_jax_memories_memory.py | 2 +- tests/torch/test_torch_utils_spaces.py | 2 +- 31 files changed, 36 insertions(+), 36 deletions(-) rename tests/{stategies.py => strategies.py} (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bd2fa15d..9ffd8cd0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: - id: end-of-file-fixer - id: name-tests-test args: ["--pytest-test-first"] - exclude: ^(tests/stategies.py|tests/utils.py) + exclude: ^(tests/strategies.py|tests/utils.py) - id: no-commit-to-branch - id: trailing-whitespace - repo: https://github.com/codespell-project/codespell diff --git a/docs/source/api/agents/sarsa.rst b/docs/source/api/agents/sarsa.rst index e7759b54..6b37c4d6 100644 --- a/docs/source/api/agents/sarsa.rst +++ b/docs/source/api/agents/sarsa.rst @@ -3,7 +3,7 @@ State Action Reward State Action (SARSA) SARSA is a **model-free** **on-policy** algorithm that uses a **tabular** Q-function to handle **discrete** observations and action spaces -Paper: `On-Line Q-Learning Using Connectionist Systems `_ +Paper: `On-Line Q-Learning Using Connectionist Systems `_ .. 
raw:: html diff --git a/docs/source/examples/real_world/franka_emika_panda/reaching_franka_real_env.py b/docs/source/examples/real_world/franka_emika_panda/reaching_franka_real_env.py index 06c0d20c..2f6cff97 100644 --- a/docs/source/examples/real_world/franka_emika_panda/reaching_franka_real_env.py +++ b/docs/source/examples/real_world/franka_emika_panda/reaching_franka_real_env.py @@ -141,7 +141,7 @@ def _get_observation_reward_done(self): return self.obs_buf, reward, done def reset(self): - print("Reseting...") + print("Resetting...") # end current motion if self.motion is not None: diff --git a/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_real_env.py b/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_real_env.py index 4e5a8e9e..c5bf2352 100644 --- a/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_real_env.py +++ b/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_real_env.py @@ -84,7 +84,7 @@ def _get_observation_reward_done(self): return self.obs_buf, reward, done def reset(self): - print("Reseting...") + print("Resetting...") # go to 1) safe position, 2) random position self.robot.command_joint_position(self.robot_default_dof_pos) diff --git a/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_real_ros2_env.py b/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_real_ros2_env.py index e3598e30..8d20a953 100644 --- a/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_real_ros2_env.py +++ b/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_real_ros2_env.py @@ -113,8 +113,8 @@ def _callback_joint_states(self, msg): self.robot_state["joint_velocity"] = np.array(msg.velocity) def _callback_end_effector_pose(self, msg): - positon = msg.position - self.robot_state["cartesian_position"] = np.array([positon.x, positon.y, positon.z]) + position = msg.position + self.robot_state["cartesian_position"] = np.array([position.x, position.y, position.z]) def _get_observation_reward_done(self): # observation @@ -146,7 +146,7 @@ def _get_observation_reward_done(self): return self.obs_buf, reward, done def reset(self): - print("Reseting...") + print("Resetting...") # go to 1) safe position, 2) random position msg = sensor_msgs.msg.JointState() diff --git a/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_real_ros_env.py b/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_real_ros_env.py index a6df08e6..90ad79f2 100644 --- a/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_real_ros_env.py +++ b/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_real_ros_env.py @@ -90,8 +90,8 @@ def _callback_joint_states(self, msg): self.robot_state["joint_velocity"] = np.array(msg.velocity) def _callback_end_effector_pose(self, msg): - positon = msg.position - self.robot_state["cartesian_position"] = np.array([positon.x, positon.y, positon.z]) + position = msg.position + self.robot_state["cartesian_position"] = np.array([position.x, position.y, position.z]) def _get_observation_reward_done(self): # observation @@ -123,7 +123,7 @@ def _get_observation_reward_done(self): return self.obs_buf, reward, done def reset(self): - print("Reseting...") + print("Resetting...") # go to 1) safe position, 2) random position msg = sensor_msgs.msg.JointState() diff --git a/docs/source/examples/utils/tensorboard_file_iterator.py b/docs/source/examples/utils/tensorboard_file_iterator.py index 01ffb04c..357e0280 100644 --- a/docs/source/examples/utils/tensorboard_file_iterator.py +++ 
b/docs/source/examples/utils/tensorboard_file_iterator.py @@ -19,7 +19,7 @@ mean = np.mean(rewards[:,:,1], axis=0) std = np.std(rewards[:,:,1], axis=0) -# creae two subplots (one for each reward and one for the mean) +# create two subplots (one for each reward and one for the mean) fig, ax = plt.subplots(1, 2, figsize=(15, 5)) # plot the rewards for each experiment diff --git a/skrl/__init__.py b/skrl/__init__.py index a41d7944..ff811238 100644 --- a/skrl/__init__.py +++ b/skrl/__init__.py @@ -160,7 +160,7 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device": This function supports the PyTorch-like ``"type:ordinal"`` string specification (e.g.: ``"cuda:0"``). - :param device: Device specification. If the specified device is ``None`` ot it cannot be resolved, + :param device: Device specification. If the specified device is ``None`` or it cannot be resolved, the default available device will be returned instead. :return: JAX Device. diff --git a/skrl/agents/jax/a2c/a2c.py b/skrl/agents/jax/a2c/a2c.py index ec3dd53c..2a69f615 100644 --- a/skrl/agents/jax/a2c/a2c.py +++ b/skrl/agents/jax/a2c/a2c.py @@ -413,7 +413,7 @@ def record_transition( values = jax.device_get(values) values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: rewards += self._discount_factor * values * truncated diff --git a/skrl/agents/jax/ppo/ppo.py b/skrl/agents/jax/ppo/ppo.py index fd5858c1..ce9500db 100644 --- a/skrl/agents/jax/ppo/ppo.py +++ b/skrl/agents/jax/ppo/ppo.py @@ -444,7 +444,7 @@ def record_transition( values = jax.device_get(values) values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: rewards += self._discount_factor * values * truncated diff --git a/skrl/agents/jax/rpo/rpo.py b/skrl/agents/jax/rpo/rpo.py index d61f80a9..35bf9e4e 100644 --- a/skrl/agents/jax/rpo/rpo.py +++ b/skrl/agents/jax/rpo/rpo.py @@ -452,7 +452,7 @@ def record_transition( values = jax.device_get(values) values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: rewards += self._discount_factor * values * truncated diff --git a/skrl/agents/torch/a2c/a2c.py b/skrl/agents/torch/a2c/a2c.py index cb08cca8..588767d4 100644 --- a/skrl/agents/torch/a2c/a2c.py +++ b/skrl/agents/torch/a2c/a2c.py @@ -264,7 +264,7 @@ def record_transition( values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value") values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: rewards += self._discount_factor * values * truncated diff --git a/skrl/agents/torch/a2c/a2c_rnn.py b/skrl/agents/torch/a2c/a2c_rnn.py index dd93e0dd..b241baf4 100644 --- a/skrl/agents/torch/a2c/a2c_rnn.py +++ b/skrl/agents/torch/a2c/a2c_rnn.py @@ -305,7 +305,7 @@ def record_transition( values, _, outputs = self.value.act({"states": self._state_preprocessor(states), **rnn}, role="value") values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: rewards += self._discount_factor * values * truncated diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py index 
45e40659..de4ffda4 100644 --- a/skrl/agents/torch/amp/amp.py +++ b/skrl/agents/torch/amp/amp.py @@ -365,7 +365,7 @@ def record_transition( values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value") values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: rewards += self._discount_factor * values * truncated diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py index e6be90c4..7c15a9ce 100644 --- a/skrl/agents/torch/ppo/ppo.py +++ b/skrl/agents/torch/ppo/ppo.py @@ -289,7 +289,7 @@ def record_transition( values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value") values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: rewards += self._discount_factor * values * truncated diff --git a/skrl/agents/torch/ppo/ppo_rnn.py b/skrl/agents/torch/ppo/ppo_rnn.py index 7259507c..9e22be4e 100644 --- a/skrl/agents/torch/ppo/ppo_rnn.py +++ b/skrl/agents/torch/ppo/ppo_rnn.py @@ -320,7 +320,7 @@ def record_transition( values, _, outputs = self.value.act({"states": self._state_preprocessor(states), **rnn}, role="value") values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: rewards += self._discount_factor * values * truncated diff --git a/skrl/agents/torch/rpo/rpo.py b/skrl/agents/torch/rpo/rpo.py index 24561945..e91ce293 100644 --- a/skrl/agents/torch/rpo/rpo.py +++ b/skrl/agents/torch/rpo/rpo.py @@ -285,7 +285,7 @@ def record_transition( ) values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: rewards += self._discount_factor * values * truncated diff --git a/skrl/agents/torch/rpo/rpo_rnn.py b/skrl/agents/torch/rpo/rpo_rnn.py index 5cca8648..550534d4 100644 --- a/skrl/agents/torch/rpo/rpo_rnn.py +++ b/skrl/agents/torch/rpo/rpo_rnn.py @@ -326,7 +326,7 @@ def record_transition( ) values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: rewards += self._discount_factor * values * truncated diff --git a/skrl/agents/torch/sarsa/sarsa.py b/skrl/agents/torch/sarsa/sarsa.py index c7c4b7c3..11d75be8 100644 --- a/skrl/agents/torch/sarsa/sarsa.py +++ b/skrl/agents/torch/sarsa/sarsa.py @@ -50,7 +50,7 @@ def __init__( ) -> None: """State Action Reward State Action (SARSA) - https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.17.2539 + https://scholar.google.com/scholar?q=On-line+Q-learning+using+connectionist+system :param models: Models used by the agent :type models: dictionary of skrl.models.torch.Model diff --git a/skrl/agents/torch/trpo/trpo.py b/skrl/agents/torch/trpo/trpo.py index 12e3910a..07a4fc64 100644 --- a/skrl/agents/torch/trpo/trpo.py +++ b/skrl/agents/torch/trpo/trpo.py @@ -275,7 +275,7 @@ def record_transition( values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value") values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: rewards += self._discount_factor * values * truncated diff --git 
a/skrl/agents/torch/trpo/trpo_rnn.py b/skrl/agents/torch/trpo/trpo_rnn.py index ff7804c6..4b41f332 100644 --- a/skrl/agents/torch/trpo/trpo_rnn.py +++ b/skrl/agents/torch/trpo/trpo_rnn.py @@ -316,7 +316,7 @@ def record_transition( values, _, outputs = self.value.act({"states": self._state_preprocessor(states), **rnn}, role="value") values = self._value_preprocessor(values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: rewards += self._discount_factor * values * truncated diff --git a/skrl/envs/wrappers/jax/bidexhands_envs.py b/skrl/envs/wrappers/jax/bidexhands_envs.py index b63549a9..ad6b1a9d 100644 --- a/skrl/envs/wrappers/jax/bidexhands_envs.py +++ b/skrl/envs/wrappers/jax/bidexhands_envs.py @@ -98,10 +98,10 @@ def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> Tuple[ """Perform a step in the environment :param actions: The actions to perform - :type actions: dict of nd.ndarray or jax.Array + :type actions: dict of np.ndarray or jax.Array :return: Observation, reward, terminated, truncated, info - :rtype: tuple of dict of nd.ndarray or jax.Array and any other info + :rtype: tuple of dict of np.ndarray or jax.Array and any other info """ actions = [_jax2torch(actions[uid], self._env.rl_device, self._jax) for uid in self.possible_agents] diff --git a/skrl/multi_agents/jax/ippo/ippo.py b/skrl/multi_agents/jax/ippo/ippo.py index d84b674d..caf48c4e 100644 --- a/skrl/multi_agents/jax/ippo/ippo.py +++ b/skrl/multi_agents/jax/ippo/ippo.py @@ -470,7 +470,7 @@ def record_transition( values = jax.device_get(values) values = self._value_preprocessor[uid](values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap[uid]: rewards[uid] += self._discount_factor[uid] * values * truncated[uid] diff --git a/skrl/multi_agents/jax/mappo/mappo.py b/skrl/multi_agents/jax/mappo/mappo.py index 2d1db277..b46999c9 100644 --- a/skrl/multi_agents/jax/mappo/mappo.py +++ b/skrl/multi_agents/jax/mappo/mappo.py @@ -499,7 +499,7 @@ def record_transition( values = jax.device_get(values) values = self._value_preprocessor[uid](values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap[uid]: rewards[uid] += self._discount_factor[uid] * values * truncated[uid] diff --git a/skrl/multi_agents/torch/ippo/ippo.py b/skrl/multi_agents/torch/ippo/ippo.py index c8296fa8..3331c9e7 100644 --- a/skrl/multi_agents/torch/ippo/ippo.py +++ b/skrl/multi_agents/torch/ippo/ippo.py @@ -303,7 +303,7 @@ def record_transition( ) values = self._value_preprocessor[uid](values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap[uid]: rewards[uid] += self._discount_factor[uid] * values * truncated[uid] diff --git a/skrl/multi_agents/torch/mappo/mappo.py b/skrl/multi_agents/torch/mappo/mappo.py index ae549c7d..90a904ef 100644 --- a/skrl/multi_agents/torch/mappo/mappo.py +++ b/skrl/multi_agents/torch/mappo/mappo.py @@ -332,7 +332,7 @@ def record_transition( ) values = self._value_preprocessor[uid](values, inverse=True) - # time-limit (truncation) boostrapping + # time-limit (truncation) bootstrapping if self._time_limit_bootstrap[uid]: rewards[uid] += self._discount_factor[uid] * values * truncated[uid] diff --git a/skrl/resources/preprocessors/jax/running_standard_scaler.py 
b/skrl/resources/preprocessors/jax/running_standard_scaler.py index 1de08152..fdebb3c2 100644 --- a/skrl/resources/preprocessors/jax/running_standard_scaler.py +++ b/skrl/resources/preprocessors/jax/running_standard_scaler.py @@ -20,7 +20,7 @@ def _copyto(dst, src): @jax.jit def _parallel_variance( running_mean: jax.Array, running_variance: jax.Array, current_count: jax.Array, array: jax.Array -) -> Tuple[jax.Array, jax.Array, jax.Array]: # yapf: disable +) -> Tuple[jax.Array, jax.Array, jax.Array]: # ddof = 1: https://github.com/pytorch/pytorch/issues/50010 if array.ndim == 3: input_mean = jnp.mean(array, axis=(0, 1)) @@ -45,7 +45,7 @@ def _parallel_variance( @jax.jit def _inverse( running_mean: jax.Array, running_variance: jax.Array, clip_threshold: float, array: jax.Array -) -> jax.Array: # yapf: disable +) -> jax.Array: return jnp.sqrt(running_variance) * jnp.clip(array, -clip_threshold, clip_threshold) + running_mean diff --git a/tests/jax/test_jax_utils_spaces.py b/tests/jax/test_jax_utils_spaces.py index 2a845193..1933fbc9 100644 --- a/tests/jax/test_jax_utils_spaces.py +++ b/tests/jax/test_jax_utils_spaces.py @@ -18,7 +18,7 @@ untensorize_space, ) -from ..stategies import gym_space_stategy, gymnasium_space_stategy +from ..strategies import gym_space_stategy, gymnasium_space_stategy def _check_backend(x, backend): diff --git a/tests/stategies.py b/tests/strategies.py similarity index 100% rename from tests/stategies.py rename to tests/strategies.py diff --git a/tests/test_jax_memories_memory.py b/tests/test_jax_memories_memory.py index 1f0d996f..49e6b2f5 100644 --- a/tests/test_jax_memories_memory.py +++ b/tests/test_jax_memories_memory.py @@ -105,7 +105,7 @@ def test_sample(self): import sys if not sys.argv[-1] == "--debug": - raise RuntimeError("Test can only be runned manually with --debug flag") + raise RuntimeError("Test can only be run manually with --debug flag") test = TestCase() test.setUp() diff --git a/tests/torch/test_torch_utils_spaces.py b/tests/torch/test_torch_utils_spaces.py index bc9c8126..da117412 100644 --- a/tests/torch/test_torch_utils_spaces.py +++ b/tests/torch/test_torch_utils_spaces.py @@ -17,7 +17,7 @@ untensorize_space, ) -from ..stategies import gym_space_stategy, gymnasium_space_stategy +from ..strategies import gym_space_stategy, gymnasium_space_stategy def _check_backend(x, backend): From 30eb1b58e1be94ee0fd0d655a7256e149bd5d93c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Mon, 4 Nov 2024 17:19:31 -0500 Subject: [PATCH 8/8] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a9f5d769..210351ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - Make flattened tensor storage in memory the default option (revert changed introduced in version 1.3.0) - Drop support for PyTorch versions prior to 1.10 (the previous supported version was 1.9). +### Changed (breaking changes: style) +- Format code using Black code formatter (it's ugly, yes, but it does its job) + ### Fixed - Moved the batch sampling inside gradient step loop for DQN, DDQN, DDPG (RNN), TD3 (RNN), SAC and SAC (RNN)