Merge pull request #9 from EdanToledo/feat/mpo
Feat/mpo
EdanToledo authored Feb 27, 2024
2 parents a08d567 + f4219de commit 479ff1e
Showing 19 changed files with 2,358 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -61,7 +61,7 @@ exclude = [
".eggs",
]
max-line-length=100
max-cognitive-complexity=11
max-cognitive-complexity=15
import-order-style = "google"
application-import-names = "stoix"
doctests = true
7 changes: 7 additions & 0 deletions stoix/configs/default_ff_mpo.yaml
@@ -0,0 +1,7 @@
defaults:
- logger: ff_mpo
- arch: anakin
- system: ff_mpo
- network: mlp_mpo
- env: gymnax/cartpole
- _self_
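
For reference, this defaults list is a standard Hydra composition: the logger, arch, system, network, and env groups are merged into one config, with _self_ setting where this file's own keys enter the merge order. A minimal sketch using Hydra's compose API follows; the config path and printed keys are inferred from the file layout in this diff and may not match the actual Stoix entry points.

# Sketch only: composing the new MPO config with Hydra's compose API.
# Assumes it is run from the repository root; config_path is inferred from
# the file locations in this diff and may differ in practice.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="stoix/configs"):
    cfg = compose(config_name="default_ff_mpo")
    print(OmegaConf.to_yaml(cfg.system))  # epochs, buffer_size, epsilon, ...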
7 changes: 7 additions & 0 deletions stoix/configs/default_ff_mpo_continuous.yaml
@@ -0,0 +1,7 @@
defaults:
- logger: ff_mpo
- arch: anakin
- system: ff_mpo_continuous
- network: mlp_mpo_continuous
- env: brax/ant
- _self_
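
The continuous variant composes the same way, swapping in the continuous system and network groups and a Brax environment. Groups and individual keys can be overridden at compose time; the overrides below only touch keys that appear in this diff and are illustrative, not prescribed by the PR.

# Sketch only: overriding keys of the continuous MPO config at compose time.
from hydra import compose, initialize

with initialize(version_base=None, config_path="stoix/configs"):
    cfg = compose(
        config_name="default_ff_mpo_continuous",
        overrides=["system.epochs=16", "system.buffer_size=500000"],
    )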
5 changes: 4 additions & 1 deletion stoix/configs/env/jumanji/snake.yaml
@@ -8,4 +8,7 @@ scenario:
   name: Snake-v1
   task_name: snake
 
-kwargs: {}
+kwargs: {
+  num_rows: 6,
+  num_cols: 6,
+}
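
The added kwargs shrink the Snake board to a 6x6 grid. A rough sketch of how such kwargs reach environment construction is shown below; it assumes the registered Jumanji Snake-v1 accepts num_rows and num_cols as constructor keywords and that Stoix forwards the kwargs dict unchanged, neither of which is shown in this diff.

# Hypothetical sketch: forwarding the YAML kwargs to jumanji.make.
# Assumes Snake-v1 takes num_rows/num_cols as keyword arguments; the actual
# plumbing inside Stoix may differ.
import jax
import jumanji

env_kwargs = {"num_rows": 6, "num_cols": 6}  # from stoix/configs/env/jumanji/snake.yaml
env = jumanji.make("Snake-v1", **env_kwargs)
state, timestep = env.reset(jax.random.PRNGKey(0))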
4 changes: 4 additions & 0 deletions stoix/configs/logger/ff_mpo.yaml
@@ -0,0 +1,4 @@
defaults:
- base_logger

system_name: ff_mpo
8 changes: 4 additions & 4 deletions stoix/configs/network/mlp.yaml
@@ -3,16 +3,16 @@ actor_network:
   pre_torso:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256, 256]
-    use_layer_norm: True
-    activation: silu
+    use_layer_norm: False
+    activation: relu
   action_head:
     _target_: stoix.networks.heads.CategoricalHead
 
 critic_network:
   pre_torso:
     _target_: stoix.networks.torso.MLPTorso
     layer_sizes: [256, 256]
-    use_layer_norm: True
-    activation: silu
+    use_layer_norm: False
+    activation: relu
   critic_head:
     _target_: stoix.networks.heads.ScalarCriticHead
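
These network entries are plain Hydra instantiation targets: _target_ names a class and the remaining keys become constructor keyword arguments. The sketch below shows that mechanism for one block; it assumes stoix is importable and that MLPTorso accepts exactly the keyword arguments listed in the config.

# Sketch of Hydra's _target_ instantiation for the actor pre_torso block.
from hydra.utils import instantiate
from omegaconf import OmegaConf

torso_cfg = OmegaConf.create(
    {
        "_target_": "stoix.networks.torso.MLPTorso",
        "layer_sizes": [256, 256],
        "use_layer_norm": False,
        "activation": "relu",
    }
)
torso = instantiate(torso_cfg)  # roughly MLPTorso(layer_sizes=[256, 256], use_layer_norm=False, activation="relu")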
20 changes: 20 additions & 0 deletions stoix/configs/network/mlp_mpo.yaml
@@ -0,0 +1,20 @@
# ---MLP MPO Networks---
actor_network:
  pre_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256]
    use_layer_norm: False
    activation: relu
  action_head:
    _target_: stoix.networks.heads.CategoricalHead

q_network:
  input_layer:
    _target_: stoix.networks.inputs.ObservationActionInput
  pre_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256]
    use_layer_norm: False
    activation: relu
  critic_head:
    _target_: stoix.networks.heads.ScalarCriticHead
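
Compared to the plain MLP config, the MPO network config adds a Q-network whose input layer combines the observation with a candidate action before the torso, so the critic can score the actions sampled for the MPO policy update. The Flax sketch below is a hypothetical approximation of that wiring, not the actual ObservationActionInput, MLPTorso, or ScalarCriticHead implementations.

# Hypothetical approximation of an observation-action critic in Flax.
import flax.linen as nn
import jax.numpy as jnp


class ObsActionCritic(nn.Module):
    layer_sizes: tuple = (256, 256)

    @nn.compact
    def __call__(self, observation: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        # "input_layer": merge observation and (one-hot, for discrete envs) action.
        x = jnp.concatenate([observation, action], axis=-1)
        # "pre_torso": MLP with ReLU activations, as in the config.
        for size in self.layer_sizes:
            x = nn.relu(nn.Dense(size)(x))
        # "critic_head": scalar Q-value.
        return jnp.squeeze(nn.Dense(1)(x), axis=-1)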
20 changes: 20 additions & 0 deletions stoix/configs/network/mlp_mpo_continuous.yaml
@@ -0,0 +1,20 @@
# ---MLP MPO Networks---
actor_network:
  pre_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256, 256]
    use_layer_norm: False
    activation: relu
  action_head:
    _target_: stoix.networks.heads.MultivariateNormalDiagHead

q_network:
  input_layer:
    _target_: stoix.networks.inputs.ObservationActionInput
  pre_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256, 256]
    use_layer_norm: False
    activation: relu
  critic_head:
    _target_: stoix.networks.heads.ScalarCriticHead
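
The continuous actor differs only in its head, which parameterises a diagonal Gaussian over actions instead of a categorical distribution. The sketch below is an assumption about what MultivariateNormalDiagHead produces, expressed with distrax, not its actual implementation.

# Hypothetical diagonal-Gaussian action head (stand-in for MultivariateNormalDiagHead).
import distrax
import flax.linen as nn
import jax


class DiagGaussianHead(nn.Module):
    action_dim: int
    min_scale: float = 1e-3

    @nn.compact
    def __call__(self, embedding: jax.Array) -> distrax.Distribution:
        loc = nn.Dense(self.action_dim)(embedding)
        scale = jax.nn.softplus(nn.Dense(self.action_dim)(embedding)) + self.min_scale
        return distrax.MultivariateNormalDiag(loc=loc, scale_diag=scale)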
34 changes: 34 additions & 0 deletions stoix/configs/system/ff_mpo.yaml
@@ -0,0 +1,34 @@
# --- Defaults FF-MPO ---

total_timesteps: 1e8 # Set the total environment steps.
# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
num_updates: ~ # Number of updates
seed: 0

# --- RL hyperparameters ---
update_batch_size: 1 # Number of vectorised gradient updates per device.
rollout_length: 8 # Number of environment steps per vectorised environment.
epochs: 16 # Number of sgd steps per rollout.
warmup_steps: 256 # Number of steps to collect before training.
buffer_size: 1_000_000 # size of the replay buffer.
batch_size: 8 # Number of samples to train on per device.
sample_sequence_length: 8 # Number of steps to consider for each element of the batch.
period : 1 # Period of the sampled sequences.
actor_lr: 1e-4 # the learning rate of the policy network optimizer
q_lr: 1e-4 # the learning rate of the Q network optimizer
dual_lr: 1e-2 # the learning rate of the alpha optimizer
tau: 0.005 # smoothing coefficient for target networks
gamma: 0.99 # discount factor
max_grad_norm: 40.0 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
max_abs_reward : 20_000 # maximum absolute reward value
num_samples: 20 # Number of MPO action samples for the policy update.
epsilon: 0.01 # KL constraint on the non-parametric auxiliary policy, the one associated with the dual variable called temperature.
epsilon_policy : 1e-3 # KL constraint on the categorical policy, the one associated with the dual variable called alpha.
init_log_temperature: 3. # initial value for the temperature in log-space, note a softplus (rather than an exp) will be used to transform this.
init_log_alpha: 3. # initial value for the alpha value in log-space, note a softplus (rather than an exp) will be used to transform this.
stochastic_policy_eval: True # whether to use a stochastic policy for Q function target evaluation.
use_online_policy_to_bootstrap: False # whether to use the online policy to bootstrap the Q function targets.
use_retrace : False # whether to use the retrace algorithm for off-policy correction.
retrace_lambda : 0.95 # the retrace lambda parameter.
n_step_for_sequence_bootstrap : 5 # the number of steps to use for the sequence bootstrap. This is only used if use_retrace is False.
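
Several of these hyperparameters (num_samples, epsilon, init_log_temperature) drive the MPO E-step: a temperature dual reweights the sampled actions by their Q-values subject to a KL budget of epsilon, and, as the comments note, the dual is stored unconstrained and mapped through a softplus. The JAX sketch below follows the standard MPO formulation rather than the exact code added in this PR.

# Standard MPO E-step temperature dual (a sketch, not the Stoix implementation).
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def e_step(log_temperature, q_values, epsilon):
    # q_values: [num_samples, batch] Q-estimates for the sampled actions.
    temperature = jax.nn.softplus(log_temperature) + 1e-8  # softplus transform, per the config comment
    tempered_q = jax.lax.stop_gradient(q_values) / temperature
    weights = jax.nn.softmax(tempered_q, axis=0)  # non-parametric target-policy weights
    # Temperature loss: eta * epsilon + eta * mean_s log mean_a exp(Q / eta).
    log_mean_exp = logsumexp(tempered_q, axis=0) - jnp.log(q_values.shape[0])
    temperature_loss = temperature * (epsilon + jnp.mean(log_mean_exp))
    return weights, temperature_loss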
39 changes: 39 additions & 0 deletions stoix/configs/system/ff_mpo_continuous.yaml
@@ -0,0 +1,39 @@
# --- Defaults FF-MPO ---

total_timesteps: 1e8 # Set the total environment steps.
# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
num_updates: ~ # Number of updates
seed: 0

# --- RL hyperparameters ---
update_batch_size: 1 # Number of vectorised gradient updates per device.
rollout_length: 8 # Number of environment steps per vectorised environment.
epochs: 32 # Number of sgd steps per rollout.
warmup_steps: 256 # Number of steps to collect before training.
buffer_size: 1_000_000 # size of the replay buffer.
batch_size: 8 # Number of samples to train on per device.
sample_sequence_length: 8 # Number of steps to consider for each element of the batch.
period : 1 # Period of the sampled sequences.
actor_lr: 1e-4 # the learning rate of the policy network optimizer
q_lr: 1e-4 # the learning rate of the Q network optimizer
dual_lr: 1e-2 # the learning rate of the alpha optimizer
tau: 0.005 # smoothing coefficient for target networks
gamma: 0.99 # discount factor
max_grad_norm: 40.0 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
max_abs_reward : 20_000 # maximum absolute reward value
num_samples: 20 # Number of MPO action samples for the policy update.
epsilon: 0.01 # KL constraint on the non-parametric auxiliary policy, the one associated with the dual variable called temperature.
epsilon_mean : 1e-3 # KL constraint on the mean of the Gaussian policy, the one associated with the dual variable called alpha_mean.
epsilon_stddev: 1e-5 # KL constraint on the stddev of the Gaussian policy, the one associated with the dual variable called alpha_stddev.
init_log_temperature: 10. # initial value for the temperature in log-space, note a softplus (rather than an exp) will be used to transform this.
init_log_alpha_mean: 10. # initial value for the alpha_mean in log-space, note a softplus (rather than an exp) will be used to transform this.
init_log_alpha_stddev: 1000. # initial value for the alpha_stddev in log-space, note a softplus (rather than an exp) will be used to transform this.
per_dim_constraining: True # whether to enforce the KL constraint on each dimension independently; this is the default. Otherwise the overall KL is constrained, which allows some dimensions to change more at the expense of others staying put.
action_penalization: True # whether to use a KL constraint to penalize actions via the MO-MPO algorithm.
epsilon_penalty: 0.001 # KL constraint on the probability of violating the action constraint.
stochastic_policy_eval: True # whether to use a stochastic policy for Q function target evaluation.
use_online_policy_to_bootstrap: False # whether to use the online policy to bootstrap the Q function targets.
use_retrace : False # whether to use the retrace algorithm for off-policy correction.
retrace_lambda : 0.95 # the retrace lambda parameter.
n_step_for_sequence_bootstrap : 5 # the number of steps to use for the sequence bootstrap. This is only used if use_retrace is False.
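
Relative to the discrete config, the continuous M-step replaces the single epsilon_policy constraint with decoupled constraints on the Gaussian mean and stddev (epsilon_mean, epsilon_stddev), each with its own alpha dual, and per_dim_constraining applies them per action dimension. The sketch below writes out those KL terms in the common MPO decomposition; the KL direction and shapes are assumptions, not a copy of the Stoix code.

# Decoupled per-dimension KL terms for a diagonal-Gaussian policy (a sketch).
# The mean term holds the stddev at the target values; the stddev term holds the mean.
import jax.numpy as jnp


def decoupled_kls(mean_target, std_target, mean_online, std_online):
    def kl(m1, s1, m2, s2):
        # Analytic KL( N(m1, s1) || N(m2, s2) ), elementwise per action dimension.
        return jnp.log(s2 / s1) + (s1**2 + (m1 - m2) ** 2) / (2.0 * s2**2) - 0.5

    kl_mean = kl(mean_target, std_target, mean_online, std_target)    # bounded by epsilon_mean via alpha_mean
    kl_stddev = kl(mean_target, std_target, mean_target, std_online)  # bounded by epsilon_stddev via alpha_stddev
    # With per_dim_constraining=True, each dimension keeps its own constraint;
    # otherwise the terms would be summed over the action axis.
    return kl_mean, kl_stddev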