Merge pull request #9 from EdanToledo/feat/mpo
Feat/mpo
Showing 19 changed files with 2,358 additions and 16 deletions.
[New file: default FF-MPO experiment config]
@@ -0,0 +1,7 @@
defaults:
  - logger: ff_mpo
  - arch: anakin
  - system: ff_mpo
  - network: mlp_mpo
  - env: gymnax/cartpole
  - _self_
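Reviewer note, not part of the diff: Hydra merges each entry in this defaults list from its config group (logger, arch, system, network, env), with _self_ fixing where this file's own keys land in the merge. Below is a minimal Python sketch of composing the result programmatically; the config root ("stoix/configs") and config name ("default_ff_mpo") are guesses, not taken from the diff.

# Sketch only: compose the merged FF-MPO config the way an entry point might.
# Both config_path and config_name below are assumptions.
from hydra import compose, initialize

with initialize(version_base=None, config_path="stoix/configs"):
    cfg = compose(
        config_name="default_ff_mpo",
        overrides=["env=gymnax/cartpole"],  # same group syntax as the defaults list
    )
    print(cfg)  # the fully merged logger/arch/system/network/env config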
[New file: default FF-MPO continuous-control experiment config]
@@ -0,0 +1,7 @@
defaults:
  - logger: ff_mpo
  - arch: anakin
  - system: ff_mpo_continuous
  - network: mlp_mpo_continuous
  - env: brax/ant
  - _self_
[Modified: Jumanji Snake environment config]
@@ -8,4 +8,7 @@ scenario:
   name: Snake-v1
   task_name: snake

-  kwargs: {}
+  kwargs: {
+    num_rows: 6,
+    num_cols: 6,
+  }
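These kwargs are forwarded to the environment constructor, so the change builds Snake on a 6x6 board instead of the registered default size. A quick sketch of the equivalent direct construction, assuming Jumanji's registry API forwards keyword arguments to the env:

# Sketch: what the new kwargs amount to at construction time.
import jax
import jumanji

env = jumanji.make("Snake-v1", num_rows=6, num_cols=6)  # smaller 6x6 board
state, timestep = env.reset(jax.random.PRNGKey(0))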
[New file: FF-MPO logger config]
@@ -0,0 +1,4 @@
defaults:
  - base_logger

system_name: ff_mpo
[New file: MLP MPO network config, discrete actions]
@@ -0,0 +1,20 @@
# ---MLP MPO Networks---
actor_network:
  pre_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256]
    use_layer_norm: False
    activation: relu
  action_head:
    _target_: stoix.networks.heads.CategoricalHead

q_network:
  input_layer:
    _target_: stoix.networks.inputs.ObservationActionInput
  pre_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256]
    use_layer_norm: False
    activation: relu
  critic_head:
    _target_: stoix.networks.heads.ScalarCriticHead
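The _target_ keys are standard Hydra instantiation hooks: each sub-config names a class plus its constructor kwargs. A minimal sketch of how a block like actor_network.pre_torso becomes an object, reusing the cfg from the compose sketch above:

# Sketch: hydra.utils.instantiate reads _target_ and passes the remaining keys
# as kwargs, roughly MLPTorso(layer_sizes=[256, 256], use_layer_norm=False,
# activation="relu").
from hydra.utils import instantiate

torso = instantiate(cfg.actor_network.pre_torso)
head = instantiate(cfg.actor_network.action_head)  # CategoricalHead for discrete actions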
[New file: MLP MPO network config, continuous actions]
@@ -0,0 +1,20 @@
# ---MLP MPO Networks---
actor_network:
  pre_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256, 256]
    use_layer_norm: False
    activation: relu
  action_head:
    _target_: stoix.networks.heads.MultivariateNormalDiagHead

q_network:
  input_layer:
    _target_: stoix.networks.inputs.ObservationActionInput
  pre_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256, 256]
    use_layer_norm: False
    activation: relu
  critic_head:
    _target_: stoix.networks.heads.ScalarCriticHead
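Structurally this differs from the discrete variant only in the deeper [256, 256, 256] torsos and the MultivariateNormalDiagHead, which parameterises a diagonal Gaussian over continuous actions. A hedged illustration of what such a head typically looks like in Flax/Distrax; Stoix's actual implementation may differ in details such as the scale floor or initialisation:

# Sketch of a diagonal-Gaussian action head (illustration, not Stoix's code).
import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp


class DiagGaussianHead(nn.Module):
    action_dim: int
    min_scale: float = 1e-3  # hypothetical floor keeping the stddev positive

    @nn.compact
    def __call__(self, embedding: jnp.ndarray) -> distrax.Distribution:
        loc = nn.Dense(self.action_dim)(embedding)
        scale = jax.nn.softplus(nn.Dense(self.action_dim)(embedding)) + self.min_scale
        return distrax.MultivariateNormalDiag(loc=loc, scale_diag=scale)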
[New file: FF-MPO system config, discrete actions]
@@ -0,0 +1,34 @@
# --- Defaults FF-MPO ---

total_timesteps: 1e8 # Set the total environment steps.
# If unspecified, it is derived from num_updates; otherwise num_updates is adjusted based on this value.
num_updates: ~ # Number of updates.
seed: 0

# --- RL hyperparameters ---
update_batch_size: 1 # Number of vectorised gradient updates per device.
rollout_length: 8 # Number of environment steps per vectorised environment.
epochs: 16 # Number of SGD steps per rollout.
warmup_steps: 256 # Number of steps to collect before training.
buffer_size: 1_000_000 # Size of the replay buffer.
batch_size: 8 # Number of samples to train on per device.
sample_sequence_length: 8 # Number of steps to consider for each element of the batch.
period: 1 # Period of the sampled sequences.
actor_lr: 1e-4 # Learning rate of the policy network optimizer.
q_lr: 1e-4 # Learning rate of the Q-network optimizer.
dual_lr: 1e-2 # Learning rate of the dual-variable optimizer.
tau: 0.005 # Smoothing coefficient for target networks.
gamma: 0.99 # Discount factor.
max_grad_norm: 40.0 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
max_abs_reward: 20_000 # Maximum absolute reward value.
num_samples: 20 # Number of MPO action samples for the policy update.
epsilon: 0.01 # KL constraint on the non-parametric auxiliary policy, the one associated with the dual variable called temperature.
epsilon_policy: 1e-3 # KL constraint on the categorical policy, the one associated with the dual variable called alpha.
init_log_temperature: 3. # Initial value for the temperature in log-space; note a softplus (rather than an exp) is used to transform this.
init_log_alpha: 3. # Initial value for alpha in log-space; note a softplus (rather than an exp) is used to transform this.
stochastic_policy_eval: True # Whether to use a stochastic policy for Q-function target evaluation.
use_online_policy_to_bootstrap: False # Whether to use the online policy to bootstrap the Q-function targets.
use_retrace: False # Whether to use the Retrace algorithm for off-policy correction.
retrace_lambda: 0.95 # The Retrace lambda parameter.
n_step_for_sequence_bootstrap: 5 # Number of steps to use for the sequence bootstrap; only used if use_retrace is False.
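The comments on init_log_temperature and init_log_alpha describe the usual MPO dual-variable parameterisation: the raw duals live in an unconstrained space and pass through a softplus to stay positive. A sketch of that transform together with the E-step temperature loss, following the MPO paper and Acme-style implementations rather than this repo's exact code:

# Sketch: softplus-transformed dual and the MPO temperature (E-step) loss.
import jax
import jax.numpy as jnp

_DUAL_FLOOR = 1e-8  # keeps the dual strictly positive


def temperature_loss(log_temperature, q_values, epsilon):
    # q_values: [num_samples, batch] Q-values of actions sampled from the target policy.
    temperature = jax.nn.softplus(log_temperature) + _DUAL_FLOOR
    # Dual objective: eta * (epsilon + mean_batch[ log mean_samples exp(Q / eta) ]).
    q_logsumexp = jax.nn.logsumexp(q_values / temperature, axis=0)
    log_num_samples = jnp.log(q_values.shape[0])
    return temperature * (epsilon + jnp.mean(q_logsumexp) - log_num_samples)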
[New file: FF-MPO system config, continuous actions]
@@ -0,0 +1,39 @@
# --- Defaults FF-MPO ---

total_timesteps: 1e8 # Set the total environment steps.
# If unspecified, it is derived from num_updates; otherwise num_updates is adjusted based on this value.
num_updates: ~ # Number of updates.
seed: 0

# --- RL hyperparameters ---
update_batch_size: 1 # Number of vectorised gradient updates per device.
rollout_length: 8 # Number of environment steps per vectorised environment.
epochs: 32 # Number of SGD steps per rollout.
warmup_steps: 256 # Number of steps to collect before training.
buffer_size: 1_000_000 # Size of the replay buffer.
batch_size: 8 # Number of samples to train on per device.
sample_sequence_length: 8 # Number of steps to consider for each element of the batch.
period: 1 # Period of the sampled sequences.
actor_lr: 1e-4 # Learning rate of the policy network optimizer.
q_lr: 1e-4 # Learning rate of the Q-network optimizer.
dual_lr: 1e-2 # Learning rate of the dual-variable optimizer.
tau: 0.005 # Smoothing coefficient for target networks.
gamma: 0.99 # Discount factor.
max_grad_norm: 40.0 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
max_abs_reward: 20_000 # Maximum absolute reward value.
num_samples: 20 # Number of MPO action samples for the policy update.
epsilon: 0.01 # KL constraint on the non-parametric auxiliary policy, the one associated with the dual variable called temperature.
epsilon_mean: 1e-3 # KL constraint on the mean of the Gaussian policy, the one associated with the dual variable called alpha_mean.
epsilon_stddev: 1e-5 # KL constraint on the stddev of the Gaussian policy, the one associated with the dual variable called alpha_stddev.
init_log_temperature: 10. # Initial value for the temperature in log-space; note a softplus (rather than an exp) is used to transform this.
init_log_alpha_mean: 10. # Initial value for alpha_mean in log-space; note a softplus (rather than an exp) is used to transform this.
init_log_alpha_stddev: 1000. # Initial value for alpha_stddev in log-space; note a softplus (rather than an exp) is used to transform this.
per_dim_constraining: True # Whether to enforce the KL constraint on each dimension independently; this is the default. Otherwise the overall KL is constrained, which allows some dimensions to change more at the expense of others staying put.
action_penalization: True # Whether to use a KL constraint to penalize actions via the MO-MPO algorithm.
epsilon_penalty: 0.001 # KL constraint on the probability of violating the action constraint.
stochastic_policy_eval: True # Whether to use a stochastic policy for Q-function target evaluation.
use_online_policy_to_bootstrap: False # Whether to use the online policy to bootstrap the Q-function targets.
use_retrace: False # Whether to use the Retrace algorithm for off-policy correction.
retrace_lambda: 0.95 # The Retrace lambda parameter.
n_step_for_sequence_bootstrap: 5 # Number of steps to use for the sequence bootstrap; only used if use_retrace is False.
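tau: 0.005 implies Polyak-averaged target networks rather than periodic hard copies. One common way to write that update with Optax (whether Stoix uses this exact helper is an assumption):

# Sketch: soft target update, new_target = tau * online + (1 - tau) * target.
import optax


def soft_update(online_params, target_params, tau: float = 0.005):
    # Applied leaf-wise over the parameter pytrees.
    return optax.incremental_update(online_params, target_params, step_size=tau)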