feat: allow ddpg-based systems to have any action range
EdanToledo committed Mar 8, 2024
1 parent 8e9be7e commit ac7ca3f
Showing 15 changed files with 46 additions and 44 deletions.
2 changes: 1 addition & 1 deletion stoix/configs/default_ff_d4pg.yaml
@@ -3,5 +3,5 @@ defaults:
   - arch: anakin
   - system: ff_d4pg
   - network: mlp_d4pg
-  - env: brax/ant
+  - env: gymnax/pendulum
   - _self_
2 changes: 1 addition & 1 deletion stoix/configs/default_ff_ppo_continuous.yaml
@@ -3,5 +3,5 @@ defaults:
   - arch: anakin
   - system: ff_ppo
   - network: mlp_continuous
-  - env: brax/ant
+  - env: gymnax/pendulum
   - _self_
4 changes: 2 additions & 2 deletions stoix/configs/network/mlp_c51.yaml
@@ -3,7 +3,7 @@ actor_network:
   pre_torso:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256, 256]
-    use_layer_norm: True
-    activation: silu
+    use_layer_norm: False
+    activation: relu
   action_head:
     _target_: stoix.networks.heads.DistributionalDiscreteQNetwork
8 changes: 4 additions & 4 deletions stoix/configs/network/mlp_continuous.yaml
@@ -3,16 +3,16 @@ actor_network:
   pre_torso:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256, 256]
-    use_layer_norm: True
-    activation: silu
+    use_layer_norm: False
+    activation: relu
   action_head:
     _target_: stoix.networks.heads.NormalTanhDistributionHead

 critic_network:
   pre_torso:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256, 256]
-    use_layer_norm: True
-    activation: silu
+    use_layer_norm: False
+    activation: relu
   critic_head:
     _target_: stoix.networks.heads.ScalarCriticHead
8 changes: 4 additions & 4 deletions stoix/configs/network/mlp_d4pg.yaml
@@ -3,8 +3,8 @@ actor_network:
   pre_torso:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256, 256, 256]
-    use_layer_norm: True
-    activation: silu
+    use_layer_norm: False
+    activation: relu
   action_head:
     _target_: stoix.networks.heads.DeterministicHead
   post_processor:
@@ -16,7 +16,7 @@ q_network:
   pre_torso:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256, 256, 256]
-    use_layer_norm: True
-    activation: silu
+    use_layer_norm: False
+    activation: relu
   critic_head:
     _target_: stoix.networks.heads.DistributionalContinuousQNetwork
12 changes: 6 additions & 6 deletions stoix/configs/network/mlp_ddpg.yaml
@@ -2,9 +2,9 @@
 actor_network:
   pre_torso:
     _target_: stoix.networks.torso.MLPTorso
-    layer_sizes: [256, 256, 256]
-    use_layer_norm: True
-    activation: silu
+    layer_sizes: [64, 64]
+    use_layer_norm: False
+    activation: relu
   action_head:
     _target_: stoix.networks.heads.DeterministicHead
   post_processor:
@@ -15,8 +15,8 @@ q_network:
     _target_: stoix.networks.inputs.ObservationActionInput
   pre_torso:
     _target_: stoix.networks.torso.MLPTorso
-    layer_sizes: [256, 256, 256]
-    use_layer_norm: True
-    activation: silu
+    layer_sizes: [64, 64]
+    use_layer_norm: False
+    activation: relu
   critic_head:
     _target_: stoix.networks.heads.ScalarCriticHead
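
Side note: these network configs are consumed with Hydra's instantiate, which imports the `_target_` path and passes the remaining keys as constructor arguments. A minimal sketch (not part of this commit, assuming Stoix and Hydra are installed) of building the updated, smaller DDPG actor torso:

import hydra
from omegaconf import OmegaConf

# Mirrors the new actor pre_torso values in mlp_ddpg.yaml.
torso_cfg = OmegaConf.create(
    {
        "_target_": "stoix.networks.torso.MLPTorso",
        "layer_sizes": [64, 64],
        "use_layer_norm": False,
        "activation": "relu",
    }
)
torso = hydra.utils.instantiate(torso_cfg)  # builds the MLPTorso module
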
4 changes: 2 additions & 2 deletions stoix/configs/network/mlp_dqn.yaml
@@ -3,7 +3,7 @@ actor_network:
   pre_torso:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256, 256]
-    use_layer_norm: True
-    activation: silu
+    use_layer_norm: False
+    activation: relu
   action_head:
     _target_: stoix.networks.heads.DiscreteQNetworkHead
4 changes: 2 additions & 2 deletions stoix/configs/network/mlp_qr_dqn.yaml
@@ -3,7 +3,7 @@ actor_network:
   pre_torso:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256, 256]
-    use_layer_norm: True
-    activation: silu
+    use_layer_norm: False
+    activation: relu
   action_head:
     _target_: stoix.networks.heads.QuantileDiscreteQNetwork
4 changes: 2 additions & 2 deletions stoix/configs/network/mlp_sac.yaml
@@ -4,7 +4,7 @@ actor_network:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [32, 32, 32, 32]
     use_layer_norm: False
-    activation: silu
+    activation: relu
   action_head:
     _target_: stoix.networks.heads.NormalTanhDistributionHead

@@ -15,6 +15,6 @@ q_network:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256, 256]
     use_layer_norm: True
-    activation: silu
+    activation: relu
   critic_head:
     _target_: stoix.networks.heads.ScalarCriticHead
8 changes: 4 additions & 4 deletions stoix/configs/network/rnn.yaml
@@ -4,12 +4,12 @@ actor_network:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256]
     use_layer_norm: False
-    activation: silu
+    activation: relu
   post_torso:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256]
     use_layer_norm: False
-    activation: silu
+    activation: relu
   action_head:
     _target_: stoix.networks.heads.CategoricalHead

@@ -18,11 +18,11 @@ critic_network:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256]
     use_layer_norm: False
-    activation: silu
+    activation: relu
   post_torso:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256]
     use_layer_norm: False
-    activation: silu
+    activation: relu
   critic_head:
     _target_: stoix.networks.heads.ScalarCriticHead
4 changes: 2 additions & 2 deletions stoix/systems/ddpg/ff_d4pg.py
@@ -355,8 +355,8 @@ def learner_setup(
     )
     action_head_post_processor = hydra.utils.instantiate(
         config.network.actor_network.post_processor,
-        minimum=-1.0,
-        maximum=1.0,
+        minimum=env.action_spec().minimum,
+        maximum=env.action_spec().maximum,
         scale_fn=tanh_to_spec,
     )
     actor_action_head = CompositeNetwork([actor_action_head, action_head_post_processor])
4 changes: 2 additions & 2 deletions stoix/systems/ddpg/ff_ddpg.py
@@ -344,8 +344,8 @@ def learner_setup(
     )
     action_head_post_processor = hydra.utils.instantiate(
         config.network.actor_network.post_processor,
-        minimum=-1.0,
-        maximum=1.0,
+        minimum=env.action_spec().minimum,
+        maximum=env.action_spec().maximum,
         scale_fn=tanh_to_spec,
     )
     actor_action_head = CompositeNetwork([actor_action_head, action_head_post_processor])
4 changes: 2 additions & 2 deletions stoix/systems/ddpg/ff_td3.py
@@ -368,8 +368,8 @@ def learner_setup(
     )
     action_head_post_processor = hydra.utils.instantiate(
         config.network.actor_network.post_processor,
-        minimum=-1.0,
-        maximum=1.0,
+        minimum=env.action_spec().minimum,
+        maximum=env.action_spec().maximum,
         scale_fn=tanh_to_spec,
     )
     actor_action_head = CompositeNetwork([actor_action_head, action_head_post_processor])
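
The three hunks above are the core of the commit: the D4PG, DDPG, and TD3 actors previously rescaled their output to a hard-coded [-1, 1] range, and now rescale to the bounds reported by env.action_spec(). The tanh_to_spec implementation itself is not part of this diff; a rough sketch of what such a scaling function typically does (name and signature here are illustrative, not Stoix's actual API):

import jax.numpy as jnp

def tanh_to_spec_sketch(action, minimum, maximum):
    """Illustrative only: squash an unbounded actor output with tanh,
    then affinely rescale from [-1, 1] to [minimum, maximum]."""
    squashed = jnp.tanh(action)
    scale = (maximum - minimum) / 2.0
    offset = (maximum + minimum) / 2.0
    return squashed * scale + offset

With minimum=-1.0 and maximum=1.0 this reduces to a plain tanh, which is why the old hard-coded bounds only suited environments whose actions already lie in [-1, 1].
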
15 changes: 8 additions & 7 deletions stoix/utils/make_env.py
@@ -15,7 +15,7 @@
 from jumanji.env import Environment
 from jumanji.registration import _REGISTRY as JUMANJI_REGISTRY
 from jumanji.specs import BoundedArray, MultiDiscreteArray
-from jumanji.wrappers import AutoResetWrapper, MultiToSingleWrapper
+from jumanji.wrappers import MultiToSingleWrapper
 from omegaconf import DictConfig
 from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY

@@ -24,6 +24,7 @@
 from stoix.wrappers.brax import BraxJumanjiWrapper
 from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper
 from stoix.wrappers.jumanji import MultiBoundedToBounded, MultiDiscreteToDiscrete
+from stoix.wrappers.truncation import TruncationAutoResetWrapper
 from stoix.wrappers.xminigrid import XMiniGridWrapper


@@ -60,7 +61,7 @@ def make_jumanji_env(
         config.env.multi_agent,
     )

-    env = AutoResetWrapper(env)
+    env = TruncationAutoResetWrapper(env)
     env = RecordEpisodeMetrics(env)

     return env, eval_env
@@ -85,7 +86,7 @@ def make_gymnax_env(env_name: str, config: DictConfig) -> Tuple[Environment, Env
     env = GymnaxWrapper(env, env_params)
     eval_env = GymnaxWrapper(eval_env, eval_env_params)

-    env = AutoResetWrapper(env)
+    env = TruncationAutoResetWrapper(env)
     env = RecordEpisodeMetrics(env)

     return env, eval_env
@@ -112,7 +113,7 @@ def make_xland_minigrid_env(env_name: str, config: DictConfig) -> Tuple[Environm
     env = XMiniGridWrapper(env, env_params, config.env.flatten_observation)
     eval_env = XMiniGridWrapper(eval_env, eval_env_params, config.env.flatten_observation)

-    env = AutoResetWrapper(env)
+    env = TruncationAutoResetWrapper(env)
     env = RecordEpisodeMetrics(env)

     return env, eval_env
@@ -189,7 +190,7 @@ def make_jaxmarl_env(
     else:
         raise ValueError(f"Unsupported action spec for JAXMarl {env.action_spec()}.")

-    env = AutoResetWrapper(env)
+    env = TruncationAutoResetWrapper(env)
     env = RecordEpisodeMetrics(env)

     return env, eval_env
@@ -231,7 +232,7 @@ def make_craftax_env(env_name: str, config: DictConfig) -> Tuple[Environment, En
     env = GymnaxWrapper(env, env_params)
     eval_env = GymnaxWrapper(eval_env, eval_env_params)

-    env = AutoResetWrapper(env)
+    env = TruncationAutoResetWrapper(env)
     env = RecordEpisodeMetrics(env)

     return env, eval_env
@@ -252,7 +253,7 @@ def make_debug_env(env_name: str, config: DictConfig) -> Tuple[Environment, Envi
     env = IdentityGame(**config.env.kwargs)
     eval_env = IdentityGame(**config.env.kwargs)

-    env = AutoResetWrapper(env)
+    env = TruncationAutoResetWrapper(env)
     env = RecordEpisodeMetrics(env)

     return env, eval_env
7 changes: 4 additions & 3 deletions stoix/wrappers/truncation.py
@@ -15,18 +15,19 @@ class TruncationAutoResetWrapper(Wrapper):
     the state, observation, and step_type are reset. The observation and step_type of the
     terminal TimeStep is reset to the reset observation and StepType.LAST, respectively.
     The reward, discount, and extras retrieved from the transition to the terminal state.
-    NOTE: The observation from the terminal TimeStep is stored in timestep.extras["real_next_obs"].
+    NOTE: The observation from the terminal TimeStep is stored in
+    timestep.extras["final_observation"].
     WARNING: do not `jax.vmap` the wrapped environment (e.g. do not use with the `VmapWrapper`),
     which would lead to inefficient computation due to both the `step` and `reset` functions
     being processed each time `step` is called. Please use the `VmapAutoResetWrapper` instead.
     """

-    OBS_IN_EXTRAS_KEY = "real_next_obs"
+    OBS_IN_EXTRAS_KEY = "final_observation"

     def _obs_in_extras(
         self, state: State, timestep: TimeStep[Observation]
     ) -> Tuple[State, TimeStep[Observation]]:
-        """Place the observation in timestep.extras[real_next_obs]."""
+        """Place the observation in timestep.extras[final_observation]."""
         extras = timestep.extras
         extras[TruncationAutoResetWrapper.OBS_IN_EXTRAS_KEY] = timestep.observation
         return state, timestep.replace(extras=extras)
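
The rename from real_next_obs to final_observation is cosmetic, but the mechanism matters for the off-policy systems above: when the wrapper auto-resets, the observation in the returned TimeStep already belongs to the new episode, so the true last observation must be read from extras. A hypothetical consumer (not taken from this commit) might build replay transitions like this:

from typing import Any, Dict

def transition_from_timestep(prev_obs: Any, action: Any, timestep: Any) -> Dict[str, Any]:
    """Illustrative helper: bootstrap from the stored final observation so that
    auto-resets (on termination or truncation) do not corrupt next_obs."""
    return {
        "obs": prev_obs,
        "action": action,
        "reward": timestep.reward,
        "discount": timestep.discount,  # 0.0 on termination, nonzero on truncation
        "next_obs": timestep.extras["final_observation"],
    }
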
