diff --git a/pyproject.toml b/pyproject.toml index f07da107..42afbf4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ exclude = [ ".eggs", ] max-line-length=100 -max-cognitive-complexity=11 +max-cognitive-complexity=15 import-order-style = "google" application-import-names = "stoix" doctests = true diff --git a/stoix/configs/default_ff_mpo.yaml b/stoix/configs/default_ff_mpo.yaml new file mode 100644 index 00000000..3aa09255 --- /dev/null +++ b/stoix/configs/default_ff_mpo.yaml @@ -0,0 +1,7 @@ +defaults: + - logger: ff_mpo + - arch: anakin + - system: ff_mpo + - network: mlp_mpo + - env: gymnax/cartpole + - _self_ diff --git a/stoix/configs/default_ff_mpo_continuous.yaml b/stoix/configs/default_ff_mpo_continuous.yaml new file mode 100644 index 00000000..4d52d595 --- /dev/null +++ b/stoix/configs/default_ff_mpo_continuous.yaml @@ -0,0 +1,7 @@ +defaults: + - logger: ff_mpo + - arch: anakin + - system: ff_mpo_continuous + - network: mlp_mpo_continuous + - env: brax/ant + - _self_ diff --git a/stoix/configs/env/jumanji/snake.yaml b/stoix/configs/env/jumanji/snake.yaml index 9fea0a02..a548eef8 100644 --- a/stoix/configs/env/jumanji/snake.yaml +++ b/stoix/configs/env/jumanji/snake.yaml @@ -8,4 +8,7 @@ scenario: name: Snake-v1 task_name: snake -kwargs: {} +kwargs: { + num_rows: 6, + num_cols: 6, +} diff --git a/stoix/configs/logger/ff_mpo.yaml b/stoix/configs/logger/ff_mpo.yaml new file mode 100644 index 00000000..fcd77baf --- /dev/null +++ b/stoix/configs/logger/ff_mpo.yaml @@ -0,0 +1,4 @@ +defaults: + - base_logger + +system_name: ff_mpo diff --git a/stoix/configs/network/mlp.yaml b/stoix/configs/network/mlp.yaml index 671782b7..02bdebda 100644 --- a/stoix/configs/network/mlp.yaml +++ b/stoix/configs/network/mlp.yaml @@ -3,8 +3,8 @@ actor_network: pre_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256, 256] - use_layer_norm: True - activation: silu + use_layer_norm: False + activation: relu action_head: _target_: stoix.networks.heads.CategoricalHead @@ -12,7 +12,7 @@ critic_network: pre_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256, 256] - use_layer_norm: True - activation: silu + use_layer_norm: False + activation: relu critic_head: _target_: stoix.networks.heads.ScalarCriticHead diff --git a/stoix/configs/network/mlp_mpo.yaml b/stoix/configs/network/mlp_mpo.yaml new file mode 100644 index 00000000..4b9b395e --- /dev/null +++ b/stoix/configs/network/mlp_mpo.yaml @@ -0,0 +1,20 @@ +# ---MLP MPO Networks--- +actor_network: + pre_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256, 256] + use_layer_norm: False + activation: relu + action_head: + _target_: stoix.networks.heads.CategoricalHead + +q_network: + input_layer: + _target_: stoix.networks.inputs.ObservationActionInput + pre_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256, 256] + use_layer_norm: False + activation: relu + critic_head: + _target_: stoix.networks.heads.ScalarCriticHead diff --git a/stoix/configs/network/mlp_mpo_continuous.yaml b/stoix/configs/network/mlp_mpo_continuous.yaml new file mode 100644 index 00000000..17e30cc5 --- /dev/null +++ b/stoix/configs/network/mlp_mpo_continuous.yaml @@ -0,0 +1,20 @@ +# ---MLP MPO Networks--- +actor_network: + pre_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256, 256, 256] + use_layer_norm: False + activation: relu + action_head: + _target_: stoix.networks.heads.MultivariateNormalDiagHead + +q_network: + input_layer: + _target_: stoix.networks.inputs.ObservationActionInput + 
pre_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256, 256, 256] + use_layer_norm: False + activation: relu + critic_head: + _target_: stoix.networks.heads.ScalarCriticHead diff --git a/stoix/configs/system/ff_mpo.yaml b/stoix/configs/system/ff_mpo.yaml new file mode 100644 index 00000000..11d91d03 --- /dev/null +++ b/stoix/configs/system/ff_mpo.yaml @@ -0,0 +1,34 @@ +# --- Defaults FF-MPO --- + +total_timesteps: 1e8 # Set the total environment steps. +# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value. +num_updates: ~ # Number of updates +seed: 0 + +# --- RL hyperparameters --- +update_batch_size: 1 # Number of vectorised gradient updates per device. +rollout_length: 8 # Number of environment steps per vectorised environment. +epochs: 16 # Number of sgd steps per rollout. +warmup_steps: 256 # Number of steps to collect before training. +buffer_size: 1_000_000 # size of the replay buffer. +batch_size: 8 # Number of samples to train on per device. +sample_sequence_length: 8 # Number of steps to consider for each element of the batch. +period : 1 # Period of the sampled sequences. +actor_lr: 1e-4 # the learning rate of the policy network optimizer +q_lr: 1e-4 # the learning rate of the Q network network optimizer +dual_lr: 1e-2 # the learning rate of the alpha optimizer +tau: 0.005 # smoothing coefficient for target networks +gamma: 0.99 # discount factor +max_grad_norm: 40.0 # Maximum norm of the gradients for a weight update. +decay_learning_rates: False # Whether learning rates should be linearly decayed during training. +max_abs_reward : 20_000 # maximum absolute reward value +num_samples: 20 # Number of MPO action samples for the policy update. +epsilon: 0.01 # KL constraint on the non-parametric auxiliary policy, the one associated with the dual variable called temperature. +epsilon_policy : 1e-3 # KL constraint on the categorical policy, the one associated with the dual variable called alpha. +init_log_temperature: 3. # initial value for the temperature in log-space, note a softplus (rather than an exp) will be used to transform this. +init_log_alpha: 3. # initial value for the alpha value in log-space, note a softplus (rather than an exp) will be used to transform this. +stochastic_policy_eval: True # whether to use a stochastic policy for Q function target evaluation. +use_online_policy_to_bootstrap: False # whether to use the online policy to bootstrap the Q function targets. +use_retrace : False # whether to use the retrace algorithm for off-policy correction. +retrace_lambda : 0.95 # the retrace lambda parameter. +n_step_for_sequence_bootstrap : 5 # the number of steps to use for the sequence bootstrap. This is only used if use_retrace is False. diff --git a/stoix/configs/system/ff_mpo_continuous.yaml b/stoix/configs/system/ff_mpo_continuous.yaml new file mode 100644 index 00000000..d92888fe --- /dev/null +++ b/stoix/configs/system/ff_mpo_continuous.yaml @@ -0,0 +1,39 @@ +# --- Defaults FF-MPO --- + +total_timesteps: 1e8 # Set the total environment steps. +# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value. +num_updates: ~ # Number of updates +seed: 0 + +# --- RL hyperparameters --- +update_batch_size: 1 # Number of vectorised gradient updates per device. +rollout_length: 8 # Number of environment steps per vectorised environment. +epochs: 32 # Number of sgd steps per rollout. +warmup_steps: 256 # Number of steps to collect before training. 
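Purely for orientation on the `total_timesteps` / `num_updates` comment in these system configs, the standalone sketch below shows the assumed relationship between the two; the real logic lives in `stoix.utils.total_timestep_checker.check_total_timesteps` (imported by the systems but not shown in this diff), so the helper name and the example `num_envs` value here are illustrative assumptions only.

def derive_num_updates(
    total_timesteps: int,
    n_devices: int,
    update_batch_size: int,
    num_envs: int,
    rollout_length: int,
) -> int:
    # Environment steps collected per learner update, mirroring the
    # steps_per_rollout computation in run_experiment later in this diff.
    steps_per_update = n_devices * update_batch_size * num_envs * rollout_length
    return int(total_timesteps) // steps_per_update

# Example: 1e8 total steps, 1 device, update_batch_size=1, 1024 envs, rollout_length=8
# gives roughly 12207 updates.
num_updates = derive_num_updates(int(1e8), 1, 1, 1024, 8)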
+buffer_size: 1_000_000 # size of the replay buffer. +batch_size: 8 # Number of samples to train on per device. +sample_sequence_length: 8 # Number of steps to consider for each element of the batch. +period : 1 # Period of the sampled sequences. +actor_lr: 1e-4 # the learning rate of the policy network optimizer +q_lr: 1e-4 # the learning rate of the Q network network optimizer +dual_lr: 1e-2 # the learning rate of the alpha optimizer +tau: 0.005 # smoothing coefficient for target networks +gamma: 0.99 # discount factor +max_grad_norm: 40.0 # Maximum norm of the gradients for a weight update. +decay_learning_rates: False # Whether learning rates should be linearly decayed during training. +max_abs_reward : 20_000 # maximum absolute reward value +num_samples: 20 # Number of MPO action samples for the policy update. +epsilon: 0.01 # KL constraint on the non-parametric auxiliary policy, the one associated with the dual variable called temperature. +epsilon_mean : 1e-3 # KL constraint on the mean of the Gaussian policy, the one associated with the dual variable called alpha_mean. +epsilon_stddev: 1e-5 # KL constraint on the stddev of the Gaussian policy, the one associated with the dual variable called alpha_mean. +init_log_temperature: 10. # initial value for the temperature in log-space, note a softplus (rather than an exp) will be used to transform this. +init_log_alpha_mean: 10. # initial value for the alpha_mean in log-space, note a softplus (rather than an exp) will be used to transform this. +init_log_alpha_stddev: 1000. # initial value for the alpha_stddev in log-space, note a softplus (rather than an exp) will be used to transform this. +per_dim_constraining: True # whether to enforce the KL constraint on each dimension independently; this is the default. Otherwise the overall KL is constrained, which allows some dimensions to change more at the expense of others staying put. +action_penalization: True # whether to use a KL constraint to penalize actions via the MO-MPO algorithm. +epsilon_penalty: 0.001 # KL constraint on the probability of violating the action constraint. +stochastic_policy_eval: True # whether to use a stochastic policy for Q function target evaluation. +use_online_policy_to_bootstrap: False # whether to use the online policy to bootstrap the Q function targets. +use_retrace : False # whether to use the retrace algorithm for off-policy correction. +retrace_lambda : 0.95 # the retrace lambda parameter. +n_step_for_sequence_bootstrap : 5 # the number of steps to use for the sequence bootstrap. This is only used if use_retrace is False. diff --git a/stoix/systems/mpo/continuous_loss.py b/stoix/systems/mpo/continuous_loss.py new file mode 100644 index 00000000..e4edbf11 --- /dev/null +++ b/stoix/systems/mpo/continuous_loss.py @@ -0,0 +1,325 @@ +from typing import Tuple, Union + +import chex +import jax +import jax.numpy as jnp +from tensorflow_probability.substrates.jax.distributions import ( + Distribution, + Independent, + MultivariateNormalDiag, + Normal, +) + +from stoix.systems.mpo.types import DualParams, MPOStats + +# These functions are largely taken from Acme's MPO implementation: + +_MPO_FLOAT_EPSILON = 1e-8 +_MIN_LOG_TEMPERATURE = -18.0 +_MIN_LOG_ALPHA = -18.0 + +Shape = Tuple[int] +DType = type(jnp.float32) + + +def compute_weights_and_temperature_loss( + q_values: chex.Array, + epsilon: float, + temperature: chex.Array, +) -> Tuple[chex.Array, chex.Array]: + """Computes normalized importance weights for the policy optimization. 
+ + Args: + q_values: Q-values associated with the actions sampled from the target + policy; expected shape [N, B]. + epsilon: Desired constraint on the KL between the target and non-parametric + policies. + temperature: Scalar used to temper the Q-values before computing normalized + importance weights from them. This is really the Lagrange dual variable in + the constrained optimization problem, the solution of which is the + non-parametric policy targeted by the policy loss. + + Returns: + Normalized importance weights, used for policy optimization. + Temperature loss, used to adapt the temperature. + """ + + # Temper the given Q-values using the current temperature. + tempered_q_values = jax.lax.stop_gradient(q_values) / temperature + + # Compute the normalized importance weights used to compute expectations with + # respect to the non-parametric policy. + normalized_weights = jax.nn.softmax(tempered_q_values, axis=0) + normalized_weights = jax.lax.stop_gradient(normalized_weights) + + # Compute the temperature loss (dual of the E-step optimization problem). + q_logsumexp = jax.scipy.special.logsumexp(tempered_q_values, axis=0) + log_num_actions = jnp.log(q_values.shape[0] / 1.0) + loss_temperature = epsilon + jnp.mean(q_logsumexp) - log_num_actions + loss_temperature = temperature * loss_temperature + + return normalized_weights, loss_temperature + + +def compute_nonparametric_kl_from_normalized_weights( + normalized_weights: chex.Array, +) -> chex.Array: + """Estimate the actualized KL between the non-parametric and target policies.""" + + # Compute integrand. + num_action_samples = normalized_weights.shape[0] / 1.0 + integrand = jnp.log(num_action_samples * normalized_weights + 1e-8) + + # Return the expectation with respect to the non-parametric policy. + return jnp.sum(normalized_weights * integrand, axis=0) + + +def compute_cross_entropy_loss( + sampled_actions: chex.Array, + normalized_weights: chex.Array, + online_action_distribution: Distribution, +) -> chex.Array: + """Compute the cross-entropy between the online policy and the reweighted target policy. + + Args: + sampled_actions: samples used in the Monte Carlo integration in the policy + loss. Expected shape is [N, B, ...], where N is the number of sampled + actions and B is the number of sampled states. + normalized_weights: target policy multiplied by the exponentiated Q values + and normalized; expected shape is [N, B]. + online_action_distribution: policy to be optimized. + + Returns: + loss_policy_gradient: the cross-entropy loss that, when differentiated, + produces the policy gradient. + """ + + # Compute the M-step loss. + log_prob = online_action_distribution.log_prob(sampled_actions) + + # Compute the weighted average log-prob using the normalized weights. + loss_policy_gradient = -jnp.sum(log_prob * normalized_weights, axis=0) + + # Return the mean loss over the batch of states. + return jnp.mean(loss_policy_gradient, axis=0) + + +def compute_parametric_kl_penalty_and_dual_loss( + kl: chex.Array, + alpha: chex.Array, + epsilon: float, +) -> Tuple[chex.Array, chex.Array]: + """Computes the KL cost to be added to the Lagrangian and its dual loss. + + The KL cost is simply the alpha-weighted KL divergence and it is added as a + regularizer to the policy loss. The dual variable alpha itself has a loss that + can be minimized to adapt the strength of the regularizer to keep the KL + between consecutive updates at the desired target value of epsilon. + + Args: + kl: KL divergence between the target and online policies.
+ alpha: Lagrange multipliers (dual variables) for the KL constraints. + epsilon: Desired value for the KL. + + Returns: + loss_kl: alpha-weighted KL regularization to be added to the policy loss. + loss_alpha: The Lagrange dual loss minimized to adapt alpha. + """ + + # Compute the mean KL over the batch. + mean_kl = jnp.mean(kl, axis=0) + + # Compute the regularization. + loss_kl = jnp.sum(jax.lax.stop_gradient(alpha) * mean_kl) + + # Compute the dual loss. + loss_alpha = jnp.sum(alpha * (epsilon - jax.lax.stop_gradient(mean_kl))) + + return loss_kl, loss_alpha + + +def clip_dual_params(params: DualParams, per_dim_constraining: bool) -> DualParams: + clipped_params = DualParams( + log_temperature=jnp.maximum(_MIN_LOG_TEMPERATURE, params.log_temperature), + log_alpha_mean=jnp.maximum(_MIN_LOG_ALPHA, params.log_alpha_mean), + log_alpha_stddev=jnp.maximum(_MIN_LOG_ALPHA, params.log_alpha_stddev), + ) + if not per_dim_constraining: + return clipped_params + else: + return clipped_params._replace( + log_penalty_temperature=jnp.maximum( + _MIN_LOG_TEMPERATURE, params.log_penalty_temperature + ) + ) + + +def mpo_loss( + dual_params: DualParams, + online_action_distribution: Union[MultivariateNormalDiag, Independent], + target_action_distribution: Union[MultivariateNormalDiag, Independent], + target_sampled_actions: chex.Array, # Shape [N, B, D]. + target_sampled_q_values: chex.Array, # Shape [N, B]. + epsilon: float, + epsilon_mean: float, + epsilon_stddev: float, + per_dim_constraining: bool, + action_penalization: bool, + epsilon_penalty: float, +) -> Tuple[chex.Array, MPOStats]: + """Computes the decoupled MPO loss. + + Args: + dual_params: parameters tracking the temperature and the dual variables. + online_action_distribution: online distribution returned by the online + policy network; expects batch_dims of [B] and event_dims of [D]. + target_action_distribution: target distribution returned by the target + policy network; expects same shapes as online distribution. + target_sampled_actions: actions sampled from the target policy; expects shape [N, B, D]. + target_sampled_q_values: Q-values associated with each action; expects shape [N, B]. + epsilon: KL constraint on the non-parametric auxiliary policy, the one associated with the + dual variable called temperature. + epsilon_mean: KL constraint on the mean of the Gaussian policy, the one associated with the + dual variable called alpha_mean. + epsilon_stddev: KL constraint on the stddev of the Gaussian policy, the one associated with + the dual variable called alpha_stddev. + per_dim_constraining: whether to enforce the KL constraint on each dimension independently; + this is the default. Otherwise the overall KL is constrained, which allows some + dimensions to change more at the expense of others staying put. + action_penalization: whether to use a KL constraint to penalize actions via the MO-MPO + algorithm. + epsilon_penalty: KL constraint on the probability of violating the action constraint. + + Returns: + Loss, combining the policy loss, KL penalty, and dual losses required to + adapt the dual variables. + Stats, for diagnostics and tracking performance. + """ + + # Cast `MultivariateNormalDiag`s to Independent Normals. + # The latter allows us to satisfy KL constraints per-dimension.
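# --- Illustrative aside, not part of this diff: why the Independent cast matters. ---
# Wrapping a Normal in Independent keeps the underlying per-dimension Normal accessible
# via `.distribution`, so KL divergences can be taken with shape [B, D]; a
# MultivariateNormalDiag only exposes the summed KL of shape [B]. Sketch using the
# module's existing imports, with assumed example sizes B=4, D=3:
#     _loc_p, _loc_q = jnp.zeros((4, 3)), 0.1 * jnp.ones((4, 3))
#     _scale = jnp.ones((4, 3))
#     _kl_summed = MultivariateNormalDiag(_loc_p, _scale).kl_divergence(
#         MultivariateNormalDiag(_loc_q, _scale)
#     )  # shape (4,): one KL value per batch element.
#     _kl_per_dim = Independent(Normal(_loc_p, _scale)).distribution.kl_divergence(
#         Independent(Normal(_loc_q, _scale)).distribution
#     )  # shape (4, 3): one KL term per action dimension, as used below.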
+ if isinstance(target_action_distribution, MultivariateNormalDiag): + target_action_distribution = Independent( + Normal(target_action_distribution.mean(), target_action_distribution.stddev()) + ) + online_action_distribution = Independent( + Normal(online_action_distribution.mean(), online_action_distribution.stddev()) + ) + + # Transform dual variables from log-space. + # Note: using softplus instead of exponential for numerical stability. + temperature = jax.nn.softplus(dual_params.log_temperature) + _MPO_FLOAT_EPSILON + alpha_mean = jax.nn.softplus(dual_params.log_alpha_mean) + _MPO_FLOAT_EPSILON + alpha_stddev = jax.nn.softplus(dual_params.log_alpha_stddev) + _MPO_FLOAT_EPSILON + + # Get online and target means and stddevs in preparation for decomposition. + online_mean = online_action_distribution.distribution.mean() + online_scale = online_action_distribution.distribution.stddev() + target_mean = target_action_distribution.distribution.mean() + target_scale = target_action_distribution.distribution.stddev() + + # Compute normalized importance weights, used to compute expectations with + # respect to the non-parametric policy; and the temperature loss, used to + # adapt the tempering of Q-values. + normalized_weights, loss_temperature = compute_weights_and_temperature_loss( + target_sampled_q_values, epsilon, temperature + ) + + # Only needed for diagnostics: Compute estimated actualized KL between the + # non-parametric and current target policies. + kl_nonparametric = compute_nonparametric_kl_from_normalized_weights(normalized_weights) + + if action_penalization: + # Transform action penalization temperature. + penalty_temperature = ( + jax.nn.softplus(dual_params.log_penalty_temperature) + _MPO_FLOAT_EPSILON + ) + + # Compute action penalization cost. + # Note: the cost is zero in [-1, 1] and quadratic beyond. + diff_out_of_bound = target_sampled_actions - jnp.clip(target_sampled_actions, -1.0, 1.0) + cost_out_of_bound = -jnp.linalg.norm(diff_out_of_bound, axis=-1) + + penalty_normalized_weights, loss_penalty_temperature = compute_weights_and_temperature_loss( + cost_out_of_bound, epsilon_penalty, penalty_temperature + ) + + # Only needed for diagnostics: Compute estimated actualized KL between the + # non-parametric and current target policies. + penalty_kl_nonparametric = compute_nonparametric_kl_from_normalized_weights( + penalty_normalized_weights + ) + + # Combine normalized weights. + normalized_weights += penalty_normalized_weights + loss_temperature += loss_penalty_temperature + + # Decompose the online policy into fixed-mean & fixed-stddev distributions. + # This has been documented as having better performance in bandit settings, + # see e.g. https://arxiv.org/pdf/1812.02256.pdf. + fixed_stddev_distribution = Independent(Normal(loc=online_mean, scale=target_scale)) + fixed_mean_distribution = Independent(Normal(loc=target_mean, scale=online_scale)) + + # Compute the decomposed policy losses. + loss_policy_mean = compute_cross_entropy_loss( + target_sampled_actions, normalized_weights, fixed_stddev_distribution + ) + loss_policy_stddev = compute_cross_entropy_loss( + target_sampled_actions, normalized_weights, fixed_mean_distribution + ) + + # Compute the decomposed KL between the target and online policies. + if per_dim_constraining: + kl_mean = target_action_distribution.distribution.kl_divergence( + fixed_stddev_distribution.distribution + ) # Shape [B, D]. 
+ kl_stddev = target_action_distribution.distribution.kl_divergence( + fixed_mean_distribution.distribution + ) # Shape [B, D]. + else: + kl_mean = target_action_distribution.kl_divergence(fixed_stddev_distribution) # Shape [B]. + kl_stddev = target_action_distribution.kl_divergence(fixed_mean_distribution) # Shape [B]. + + # Compute the alpha-weighted KL-penalty and dual losses to adapt the alphas. + loss_kl_mean, loss_alpha_mean = compute_parametric_kl_penalty_and_dual_loss( + kl_mean, alpha_mean, epsilon_mean + ) + loss_kl_stddev, loss_alpha_stddev = compute_parametric_kl_penalty_and_dual_loss( + kl_stddev, alpha_stddev, epsilon_stddev + ) + + # Combine losses. + loss_policy = loss_policy_mean + loss_policy_stddev + loss_kl_penalty = loss_kl_mean + loss_kl_stddev + loss_dual = loss_alpha_mean + loss_alpha_stddev + loss_temperature + loss = loss_policy + loss_kl_penalty + loss_dual + + # Create statistics. + pi_stddev = online_action_distribution.distribution.stddev() + stats = MPOStats( + # Dual Variables. + dual_alpha_mean=jnp.mean(alpha_mean), + dual_alpha_stddev=jnp.mean(alpha_stddev), + dual_temperature=jnp.mean(temperature), + # Losses. + loss_policy=jnp.mean(loss), + loss_alpha=jnp.mean(loss_alpha_mean + loss_alpha_stddev), + loss_temperature=jnp.mean(loss_temperature), + # KL measurements. + kl_q_rel=jnp.mean(kl_nonparametric) / epsilon, + penalty_kl_q_rel=( + (jnp.mean(penalty_kl_nonparametric) / epsilon_penalty) if action_penalization else None + ), + kl_mean_rel=jnp.mean(kl_mean, axis=0) / epsilon_mean, + kl_stddev_rel=jnp.mean(kl_stddev, axis=0) / epsilon_stddev, + # Q measurements. + q_min=jnp.mean(jnp.min(target_sampled_q_values, axis=0)), + q_max=jnp.mean(jnp.max(target_sampled_q_values, axis=0)), + # If the policy has stddev, log summary stats for this as well. + pi_stddev_min=jnp.mean(jnp.min(pi_stddev, axis=-1)), + pi_stddev_max=jnp.mean(jnp.max(pi_stddev, axis=-1)), + # Condition number of the diagonal covariance (actually, stddev) matrix. + pi_stddev_cond=jnp.mean(jnp.max(pi_stddev, axis=-1) / jnp.min(pi_stddev, axis=-1)), + ) + + return loss, stats diff --git a/stoix/systems/mpo/discrete_loss.py b/stoix/systems/mpo/discrete_loss.py new file mode 100644 index 00000000..4123dad2 --- /dev/null +++ b/stoix/systems/mpo/discrete_loss.py @@ -0,0 +1,163 @@ +from typing import Tuple + +import chex +import jax +import jax.numpy as jnp +from tensorflow_probability.substrates.jax.distributions import Categorical + +from stoix.systems.mpo.types import CategoricalDualParams, CategoricalMPOStats + +# These functions are largely taken from Acme's MPO implementation: + +_MPO_FLOAT_EPSILON = 1e-8 +_MIN_LOG_TEMPERATURE = -18.0 +_MIN_LOG_ALPHA = -18.0 + +Shape = Tuple[int] +DType = type(jnp.float32) + + +def categorical_mpo_loss( + dual_params: CategoricalDualParams, + online_action_distribution: Categorical, + target_action_distribution: Categorical, + q_values: chex.Array, # Shape [D, B]. + epsilon: float, + epsilon_policy: float, +) -> Tuple[chex.Array, CategoricalMPOStats]: + """Computes the MPO loss for a categorical policy. + + Args: + dual_params: parameters tracking the temperature and the dual variables. + online_action_distribution: online distribution returned by the online + policy network; expects batch_dims of [B] and event_dims of [D]. + target_action_distribution: target distribution returned by the target + policy network; expects same shapes as online distribution. + q_values: Q-values associated with every action; expects shape [D, B]. 
+ epsilon: KL constraint on the non-parametric auxiliary policy, the one + associated with the dual variable called temperature. + epsilon_policy: KL constraint on the categorical policy, the one + associated with the dual variable called alpha. + + + Returns: + Loss, combining the policy loss, KL penalty, and dual losses required to + adapt the dual variables. + Stats, for diagnostics and tracking performance. + """ + + q_values = jnp.transpose(q_values) # [D, B] --> [B, D]. + + # Transform dual variables from log-space. + # Note: using softplus instead of exponential for numerical stability. + temperature = get_temperature_from_params(dual_params) + alpha = jax.nn.softplus(dual_params.log_alpha) + _MPO_FLOAT_EPSILON + + # Compute the E-step logits and the temperature loss, used to adapt the + # tempering of Q-values. + ( + logits_e_step, + loss_temperature, + ) = compute_weights_and_temperature_loss( # pytype: disable=wrong-arg-types # jax-ndarray + q_values=q_values, + logits=target_action_distribution.logits, + epsilon=epsilon, + temperature=temperature, + ) + action_distribution_e_step = Categorical(logits=logits_e_step) + + # Only needed for diagnostics: Compute estimated actualized KL between the + # non-parametric and current target policies. + kl_nonparametric = action_distribution_e_step.kl_divergence(target_action_distribution) + + # Compute the policy loss. + loss_policy = action_distribution_e_step.cross_entropy(online_action_distribution) + loss_policy = jnp.mean(loss_policy) + + # Compute the regularization. + kl = target_action_distribution.kl_divergence(online_action_distribution) + mean_kl = jnp.mean(kl, axis=0) + loss_kl = jax.lax.stop_gradient(alpha) * mean_kl + + # Compute the dual loss. + loss_alpha = alpha * (epsilon_policy - jax.lax.stop_gradient(mean_kl)) + + # Combine losses. + loss_dual = loss_alpha + loss_temperature + loss = loss_policy + loss_kl + loss_dual + + # Create statistics. + stats = CategoricalMPOStats( # pytype: disable=wrong-arg-types # jnp-type + # Dual Variables. + dual_alpha=jnp.mean(alpha), + dual_temperature=jnp.mean(temperature), + # Losses. + loss_e_step=loss_policy, + loss_m_step=loss_kl, + loss_dual=loss_dual, + loss_policy=jnp.mean(loss), + loss_alpha=jnp.mean(loss_alpha), + loss_temperature=jnp.mean(loss_temperature), + # KL measurements. + kl_q_rel=jnp.mean(kl_nonparametric) / epsilon, + kl_mean_rel=mean_kl / epsilon_policy, + # Q measurements. + q_min=jnp.mean(jnp.min(q_values, axis=0)), + q_max=jnp.mean(jnp.max(q_values, axis=0)), + entropy_online=jnp.mean(online_action_distribution.entropy()), + entropy_target=jnp.mean(target_action_distribution.entropy()), + ) + + return loss, stats + + +def compute_weights_and_temperature_loss( + q_values: chex.Array, + logits: chex.Array, + epsilon: float, + temperature: chex.Array, +) -> Tuple[chex.Array, chex.Array]: + """Computes normalized importance weights for the policy optimization. + + Args: + q_values: Q-values associated with the actions sampled from the target + policy; expected shape [B, D]. + logits: Parameters to the categorical distribution with respect to which the + expectations are going to be computed. + epsilon: Desired constraint on the KL between the target and non-parametric + policies. + temperature: Scalar used to temper the Q-values before computing normalized + importance weights from them. This is really the Lagrange dual variable in + the constrained optimization problem, the solution of which is the + non-parametric policy targeted by the policy loss. 
+ + Returns: + Normalized importance weights, used for policy optimization. + Temperature loss, used to adapt the temperature. + """ + + # Temper the given Q-values using the current temperature. + tempered_q_values = jax.lax.stop_gradient(q_values) / temperature + + # Compute the E-step normalized logits. + unnormalized_logits = tempered_q_values + jax.nn.log_softmax(logits, axis=-1) + logits_e_step = jax.nn.log_softmax(unnormalized_logits, axis=-1) + + # Compute the temperature loss (dual of the E-step optimization problem). + # Note that the log normalizer will be the same for all actions, so we choose + # only the first one. + log_normalizer = unnormalized_logits[:, 0] - logits_e_step[:, 0] + loss_temperature = temperature * (epsilon + jnp.mean(log_normalizer)) + + return logits_e_step, loss_temperature + + +def clip_categorical_mpo_params(params: CategoricalDualParams) -> CategoricalDualParams: + return params._replace( + log_temperature=jnp.maximum(_MIN_LOG_TEMPERATURE, params.log_temperature), + log_alpha=jnp.maximum(_MIN_LOG_ALPHA, params.log_alpha), + ) + + +def get_temperature_from_params(params: CategoricalDualParams) -> chex.Array: + return jax.nn.softplus(params.log_temperature) + _MPO_FLOAT_EPSILON diff --git a/stoix/systems/mpo/ff_mpo.py b/stoix/systems/mpo/ff_mpo.py new file mode 100644 index 00000000..8e14e31e --- /dev/null +++ b/stoix/systems/mpo/ff_mpo.py @@ -0,0 +1,715 @@ +import copy +import time +from typing import Any, Callable, Dict, Tuple + +import chex +import flashbax as fbx +import flax +import hydra +import jax +import jax.numpy as jnp +import optax +import rlax +from colorama import Fore, Style +from flashbax.buffers.trajectory_buffer import BufferState +from flax.core.frozen_dict import FrozenDict +from jumanji.env import Environment +from jumanji.types import TimeStep +from omegaconf import DictConfig, OmegaConf +from rich.pretty import pprint + +from stoix.evaluator import evaluator_setup +from stoix.networks.base import CompositeNetwork +from stoix.networks.base import FeedForwardActor as Actor +from stoix.systems.mpo.discrete_loss import ( + categorical_mpo_loss, + clip_categorical_mpo_params, +) +from stoix.systems.mpo.types import ( + ActorAndTarget, + CategoricalDualParams, + MPOLearnerState, + MPOOptStates, + MPOParams, + SequenceStep, +) +from stoix.systems.q_learning.types import QsAndTarget +from stoix.systems.sac.types import ContinuousQApply +from stoix.types import ActorApply, ExperimentOutput, LearnerFn, LogEnvState +from stoix.utils import make_env as environments +from stoix.utils.checkpointing import Checkpointer +from stoix.utils.jax import unreplicate_batch_dim, unreplicate_n_dims +from stoix.utils.logger import LogEvent, StoixLogger +from stoix.utils.multistep import ( + batch_n_step_bootstrapped_returns, + batch_retrace_continuous, +) +from stoix.utils.total_timestep_checker import check_total_timesteps +from stoix.utils.training import make_learning_rate + + +def get_warmup_fn( + env: Environment, + params: MPOParams, + actor_apply_fn: ActorApply, + buffer_add_fn: Callable, + config: DictConfig, +) -> Callable: + def warmup( + env_states: LogEnvState, timesteps: TimeStep, buffer_states: BufferState, keys: chex.PRNGKey + ) -> Tuple[LogEnvState, TimeStep, BufferState, chex.PRNGKey]: + def _env_step( + carry: Tuple[LogEnvState, TimeStep, chex.PRNGKey], _: Any + ) -> Tuple[Tuple[LogEnvState, TimeStep, chex.PRNGKey], SequenceStep]: + """Step the environment.""" + + env_state, last_timestep, key = carry + # SELECT ACTION + key, policy_key = 
jax.random.split(key) + actor_policy = actor_apply_fn(params.actor_params.online, last_timestep.observation) + action = actor_policy.sample(seed=policy_key) + log_prob = actor_policy.log_prob(action) + + # STEP ENVIRONMENT + env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) + + # LOG EPISODE METRICS + done = timestep.last().reshape(-1) + info = timestep.extras["episode_metrics"] + + sequence_step = SequenceStep( + last_timestep.observation, action, timestep.reward, done, log_prob, info + ) + + return (env_state, timestep, key), sequence_step + + # STEP ENVIRONMENT FOR ROLLOUT LENGTH + (env_states, timesteps, keys), traj_batch = jax.lax.scan( + _env_step, (env_states, timesteps, keys), None, config.system.warmup_steps + ) + + # Add the trajectory to the buffer. + # Swap the batch and time axes. + traj_batch = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), traj_batch) + buffer_states = buffer_add_fn(buffer_states, traj_batch) + + return env_states, timesteps, keys, buffer_states + + batched_warmup_step: Callable = jax.vmap( + warmup, in_axes=(0, 0, 0, 0), out_axes=(0, 0, 0, 0), axis_name="batch" + ) + + return batched_warmup_step + + +def get_learner_fn( + env: Environment, + apply_fns: Tuple[ActorApply, ContinuousQApply], + update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn, optax.TransformUpdateFn], + buffer_fns: Tuple[Callable, Callable], + config: DictConfig, +) -> LearnerFn[MPOLearnerState]: + """Get the learner function.""" + + # Get apply and update functions for actor and critic networks. + actor_apply_fn, q_apply_fn = apply_fns + actor_update_fn, q_update_fn, dual_update_fn = update_fns + buffer_add_fn, buffer_sample_fn = buffer_fns + + def _update_step(learner_state: MPOLearnerState, _: Any) -> Tuple[MPOLearnerState, Tuple]: + def _env_step( + learner_state: MPOLearnerState, _: Any + ) -> Tuple[MPOLearnerState, SequenceStep]: + """Step the environment.""" + params, opt_states, buffer_state, key, env_state, last_timestep = learner_state + + # SELECT ACTION + key, policy_key = jax.random.split(key) + actor_policy = actor_apply_fn(params.actor_params.online, last_timestep.observation) + action = actor_policy.sample(seed=policy_key) + log_prob = actor_policy.log_prob(action) + + # STEP ENVIRONMENT + env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) + + # LOG EPISODE METRICS + done = timestep.last().reshape(-1) + info = timestep.extras["episode_metrics"] + + sequence_step = SequenceStep( + last_timestep.observation, action, timestep.reward, done, log_prob, info + ) + + learner_state = MPOLearnerState( + params, opt_states, buffer_state, key, env_state, timestep + ) + return learner_state, sequence_step + + # STEP ENVIRONMENT FOR ROLLOUT LENGTH + learner_state, traj_batch = jax.lax.scan( + _env_step, learner_state, None, config.system.rollout_length + ) + + params, opt_states, buffer_state, key, env_state, last_timestep = learner_state + + # Add the trajectory to the buffer. + # Swap the batch and time axes. + traj_batch = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), traj_batch) + buffer_state = buffer_add_fn(buffer_state, traj_batch) + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def _actor_loss_fn( + online_actor_params: FrozenDict, + dual_params: CategoricalDualParams, + target_actor_params: FrozenDict, + target_q_params: FrozenDict, + sequence: SequenceStep, + ) -> chex.Array: + # Reshape the observations to [B*T, ...]. 
+ reshaped_obs = jax.tree_map( + lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), sequence.obs + ) + batch_length = sequence.action.shape[0] * sequence.action.shape[1] # B*T + + online_actor_policy = actor_apply_fn(online_actor_params, reshaped_obs) + target_actor_policy = actor_apply_fn(target_actor_params, reshaped_obs) + # In discrete MPO, we evaluate all actions instead of sampling. + a_improvement = jnp.arange(config.system.action_dim).astype(jnp.float32) + a_improvement = jnp.tile( + a_improvement[..., jnp.newaxis], [1, batch_length] + ) # [D, B*T] + a_improvement = jax.nn.one_hot(a_improvement, config.system.action_dim) + target_q_values = jax.vmap(q_apply_fn, in_axes=(None, None, 0))( + target_q_params, reshaped_obs, a_improvement + ) # [D, B*T] + + # Compute the policy and dual loss. + loss, loss_info = categorical_mpo_loss( + dual_params=dual_params, + online_action_distribution=online_actor_policy, + target_action_distribution=target_actor_policy, + q_values=target_q_values, + epsilon=config.system.epsilon, + epsilon_policy=config.system.epsilon_policy, + ) + + return jnp.mean(loss), loss_info + + def _q_loss_fn( + online_q_params: FrozenDict, + target_q_params: FrozenDict, + online_actor_params: FrozenDict, + target_actor_params: FrozenDict, + sequences: SequenceStep, + rng_key: chex.PRNGKey, + ) -> jnp.ndarray: + + online_actor_policy = actor_apply_fn( + online_actor_params, sequences.obs + ) # [B, T, ...] + target_actor_policy = actor_apply_fn( + target_actor_params, sequences.obs + ) # [B, T, ...] + a_t = jax.nn.one_hot(sequences.action, config.system.action_dim) # [B, T, ...] + online_q_t = q_apply_fn(online_q_params, sequences.obs, a_t) # [B, T] + + # Cast and clip rewards. + discount = 1.0 - sequence.done.astype(jnp.float32) + d_t = (discount * config.system.gamma).astype(jnp.float32) + r_t = jnp.clip( + sequence.reward, -config.system.max_abs_reward, config.system.max_abs_reward + ).astype(jnp.float32) + + # Policy to use for policy evaluation and bootstrapping. + if config.system.use_online_policy_to_bootstrap: + policy_to_evaluate = online_actor_policy + else: + policy_to_evaluate = target_actor_policy + + # Action(s) to use for policy evaluation; shape [N, B, T]. + if config.system.stochastic_policy_eval: + a_evaluation = policy_to_evaluate.sample( + seed=rng_key, sample_shape=config.system.num_samples + ) # [N, B, T, ...] + else: + a_evaluation = policy_to_evaluate.mode()[jnp.newaxis, ...] # [N=1, B, T, ...] + + # Add a stopgrad in case we use the online policy for evaluation. + a_evaluation = jax.lax.stop_gradient(a_evaluation) + a_evaluation = jax.nn.one_hot(a_evaluation, config.system.action_dim) + + # Compute the Q-values for the next state-action pairs; [N, B, T]. + q_values = jax.vmap(q_apply_fn, in_axes=(None, None, 0))( + target_q_params, sequences.obs, a_evaluation + ) + + # When policy_eval_stochastic == True, this corresponds to expected SARSA. + # Otherwise, the mean is a no-op. + v_t = jnp.mean(q_values, axis=0) # [B, T] + + if config.system.use_retrace: + # Compute the log-rhos for the retrace targets. + log_rhos = target_actor_policy.log_prob(sequences.action) - sequences.log_prob + + # Compute target Q-values + target_q_t = q_apply_fn(target_q_params, sequences.obs, a_t) # [B, T] + + # Compute retrace targets. 
+ # These targets use the rewards and discounts as in normal TD-learning but + # they use a mix of bootstrapped values V(s') and Q(s', a'), weighing the + # latter based on how likely a' is under the current policy (s' and a' are + # samples from replay). + # See [Munos et al., 2016](https://arxiv.org/abs/1606.02647) for more. + retrace_error = batch_retrace_continuous( + online_q_t[:, :-1], + target_q_t[:, 1:-1], + v_t[:, 1:], + r_t[:, :-1], + d_t[:, :-1], + log_rhos[:, 1:-1], + config.system.retrace_lambda, + ) + q_loss = rlax.l2_loss(retrace_error).mean() + else: + n_step_value_target = batch_n_step_bootstrapped_returns( + r_t[:, :-1], + d_t[:, :-1], + v_t[:, 1:], + config.system.n_step_for_sequence_bootstrap, + ) + td_error = online_q_t[:, :-1] - n_step_value_target + q_loss = rlax.l2_loss(td_error).mean() + + loss_info = { + "q_loss": q_loss, + } + + return q_loss, loss_info + + params, opt_states, buffer_state, key = update_state + + key, sample_key, q_key = jax.random.split(key, num=3) + + # SAMPLE SEQUENCES + sequence_sample = buffer_sample_fn(buffer_state, sample_key) + sequence: SequenceStep = sequence_sample.experience + + # CALCULATE ACTOR AND DUAL LOSS + actor_dual_grad_fn = jax.grad(_actor_loss_fn, argnums=(0, 1), has_aux=True) + actor_dual_grads, actor_loss_info = actor_dual_grad_fn( + params.actor_params.online, + params.dual_params, + params.actor_params.target, + params.q_params.target, + sequence, + ) + + # CALCULATE Q LOSS + q_grad_fn = jax.grad(_q_loss_fn, has_aux=True) + q_grads, q_loss_info = q_grad_fn( + params.q_params.online, + params.q_params.target, + params.actor_params.online, + params.actor_params.target, + sequence, + q_key, + ) + + # Compute the parallel mean (pmean) over the batch. + # This calculation is inspired by the Anakin architecture demo notebook. + # available at https://tinyurl.com/26tdzs5x + # This pmean could be a regular mean as the batch axis is on the same device. + actor_dual_grads, actor_loss_info = jax.lax.pmean( + (actor_dual_grads, actor_loss_info), axis_name="batch" + ) + # pmean over devices. + actor_dual_grads, actor_loss_info = jax.lax.pmean( + (actor_dual_grads, actor_loss_info), axis_name="device" + ) + + q_grads, q_loss_info = jax.lax.pmean((q_grads, q_loss_info), axis_name="batch") + q_grads, q_loss_info = jax.lax.pmean((q_grads, q_loss_info), axis_name="device") + + actor_grads, dual_grads = actor_dual_grads + + # UPDATE ACTOR PARAMS AND OPTIMISER STATE + actor_updates, actor_new_opt_state = actor_update_fn( + actor_grads, opt_states.actor_opt_state + ) + actor_new_online_params = optax.apply_updates(params.actor_params.online, actor_updates) + + # UPDATE DUAL PARAMS AND OPTIMISER STATE + dual_updates, dual_new_opt_state = dual_update_fn(dual_grads, opt_states.dual_opt_state) + dual_new_params = optax.apply_updates(params.dual_params, dual_updates) + dual_new_params = clip_categorical_mpo_params(dual_new_params) + + # UPDATE Q PARAMS AND OPTIMISER STATE + q_updates, q_new_opt_state = q_update_fn(q_grads, opt_states.q_opt_state) + q_new_online_params = optax.apply_updates(params.q_params.online, q_updates) + # Target network polyak update. 
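# (Illustrative aside, not part of this diff.) optax.incremental_update performs the
# Polyak average new_target = tau * online + (1 - tau) * target, leaf-wise over the
# parameter pytrees; with tau = 0.005 the targets track the online networks slowly.
# Tiny worked example using the optax/jnp imports already in this module:
#     _ex_online = {"w": jnp.ones(3)}
#     _ex_target = {"w": jnp.zeros(3)}
#     _ex_new_target = optax.incremental_update(_ex_online, _ex_target, 0.005)
#     # equivalent to tree_map(lambda o, t: 0.005 * o + 0.995 * t, online, target),
#     # so _ex_new_target["w"] == [0.005, 0.005, 0.005].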
+ (new_target_actor_params, new_target_q_params) = optax.incremental_update( + (actor_new_online_params, q_new_online_params), + (params.actor_params.target, params.q_params.target), + config.system.tau, + ) + + actor_new_params = ActorAndTarget(actor_new_online_params, new_target_actor_params) + q_new_params = QsAndTarget(q_new_online_params, new_target_q_params) + + # PACK NEW PARAMS AND OPTIMISER STATE + new_params = MPOParams(actor_new_params, q_new_params, dual_new_params) + new_opt_state = MPOOptStates(actor_new_opt_state, q_new_opt_state, dual_new_opt_state) + + # PACK LOSS INFO + loss_info = actor_loss_info._asdict() + loss_info = { + **loss_info, + "value_loss": q_loss_info["q_loss"], + } + return (new_params, new_opt_state, buffer_state, key), loss_info + + update_state = (params, opt_states, buffer_state, key) + + # UPDATE EPOCHS + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config.system.epochs + ) + + params, opt_states, buffer_state, key = update_state + learner_state = MPOLearnerState( + params, opt_states, buffer_state, key, env_state, last_timestep + ) + metric = traj_batch.info + return learner_state, (metric, loss_info) + + def learner_fn(learner_state: MPOLearnerState) -> ExperimentOutput[MPOLearnerState]: + """Learner function. + + This function represents the learner, it updates the network parameters + by iteratively applying the `_update_step` function for a fixed number of + updates. The `_update_step` function is vectorized over a batch of inputs. + """ + + batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") + + learner_state, (episode_info, loss_info) = jax.lax.scan( + batched_update_step, learner_state, None, config.system.num_updates_per_eval + ) + return ExperimentOutput( + learner_state=learner_state, + episode_metrics=episode_info, + train_metrics=loss_info, + ) + + return learner_fn + + +def learner_setup( + env: Environment, keys: chex.Array, config: DictConfig +) -> Tuple[LearnerFn[MPOLearnerState], Actor, MPOLearnerState]: + """Initialise learner_fn, network, optimiser, environment and states.""" + # Get available TPU cores. + n_devices = len(jax.devices()) + + # Get number of actions or action dimension from the environment. + action_dim = int(env.action_spec().num_values) + config.system.action_dim = action_dim + + # PRNG keys. + key, actor_net_key, q_net_key = keys + + # Define actor_network, q_network and optimiser. + actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) + actor_action_head = hydra.utils.instantiate( + config.network.actor_network.action_head, action_dim=action_dim + ) + actor_network = Actor(torso=actor_torso, action_head=actor_action_head) + q_network_input = hydra.utils.instantiate(config.network.q_network.input_layer) + q_network_torso = hydra.utils.instantiate(config.network.q_network.pre_torso) + q_network_head = hydra.utils.instantiate(config.network.q_network.critic_head) + q_network = CompositeNetwork([q_network_input, q_network_torso, q_network_head]) + + actor_lr = make_learning_rate(config.system.actor_lr, config, config.system.epochs) + q_lr = make_learning_rate(config.system.q_lr, config, config.system.epochs) + + actor_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(actor_lr, eps=1e-5), + ) + q_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(q_lr, eps=1e-5), + ) + + # Initialise observation: Select only obs for a single agent. 
+ init_x = env.observation_spec().generate_value() + init_x = jax.tree_util.tree_map(lambda x: x[None, ...], init_x) + init_a = jnp.zeros((1, action_dim)) + + # Initialise actor params and optimiser state. + actor_params = actor_network.init(actor_net_key, init_x) + target_actor_params = actor_params + actor_opt_state = actor_optim.init(actor_params) + + # Initialise q params and optimiser state. + online_q_params = q_network.init(q_net_key, init_x, init_a) + target_q_params = online_q_params + q_opt_state = q_optim.init(online_q_params) + + # Initialise MPO Dual params and optimiser state. + log_temperature = jnp.full([1], config.system.init_log_temperature, dtype=jnp.float32) + + log_alpha = jnp.full([1], config.system.init_log_alpha, dtype=jnp.float32) + + dual_params = CategoricalDualParams( + log_temperature=log_temperature, + log_alpha=log_alpha, + ) + + dual_lr = make_learning_rate(config.system.dual_lr, config, config.system.epochs) + dual_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(dual_lr, eps=1e-5), + ) + dual_opt_state = dual_optim.init(dual_params) + + params = MPOParams( + ActorAndTarget(actor_params, target_actor_params), + QsAndTarget(online_q_params, target_q_params), + dual_params, + ) + opt_states = MPOOptStates(actor_opt_state, q_opt_state, dual_opt_state) + + actor_network_apply_fn = actor_network.apply + q_network_apply_fn = q_network.apply + + # Pack apply and update functions. + apply_fns = (actor_network_apply_fn, q_network_apply_fn) + update_fns = (actor_optim.update, q_optim.update, dual_optim.update) + + # Create replay buffer + dummy_sequence_step = SequenceStep( + obs=jax.tree_util.tree_map(lambda x: x.squeeze(0), init_x), + action=jnp.zeros((), dtype=int), + reward=jnp.zeros((), dtype=float), + done=jnp.zeros((), dtype=bool), + log_prob=jnp.zeros((), dtype=float), + info={"episode_return": 0.0, "episode_length": 0}, + ) + + buffer_fn = fbx.make_trajectory_buffer( + max_size=config.system.buffer_size, + min_length_time_axis=config.system.sample_sequence_length, + sample_batch_size=config.system.batch_size, + sample_sequence_length=config.system.sample_sequence_length, + period=config.system.period, + add_batch_size=config.arch.num_envs, + ) + buffer_fns = (buffer_fn.add, buffer_fn.sample) + buffer_states = buffer_fn.init(dummy_sequence_step) + + # Get batched iterated update and replicate it to pmap it over cores. + learn = get_learner_fn(env, apply_fns, update_fns, buffer_fns, config) + learn = jax.pmap(learn, axis_name="device") + + warmup = get_warmup_fn(env, params, actor_network_apply_fn, buffer_fn.add, config) + warmup = jax.pmap(warmup, axis_name="device") + + # Initialise environment states and timesteps: across devices and batches. + key, *env_keys = jax.random.split( + key, n_devices * config.system.update_batch_size * config.arch.num_envs + 1 + ) + env_states, timesteps = jax.vmap(env.reset, in_axes=(0))( + jnp.stack(env_keys), + ) + reshape_states = lambda x: x.reshape( + (n_devices, config.system.update_batch_size, config.arch.num_envs) + x.shape[1:] + ) + # (devices, update batch size, num_envs, ...) + env_states = jax.tree_map(reshape_states, env_states) + timesteps = jax.tree_map(reshape_states, timesteps) + + # Load model from checkpoint if specified. 
+ if config.logger.checkpointing.load_model: + loaded_checkpoint = Checkpointer( + model_name=config.logger.system_name, + **config.logger.checkpointing.load_args, # Other checkpoint args + ) + # Restore the learner state from the checkpoint + restored_params, _ = loaded_checkpoint.restore_params(TParams=MPOParams) + # Update the params + params = restored_params + + # Define params to be replicated across devices and batches. + key, step_keys, warmup_keys = jax.random.split(key, num=3) + + replicate_learner = (params, opt_states, buffer_states, step_keys, warmup_keys) + + # Duplicate learner for update_batch_size. + broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size,) + x.shape) + replicate_learner = jax.tree_map(broadcast, replicate_learner) + + # Duplicate learner across devices. + replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices()) + + # Initialise learner state. + params, opt_states, buffer_states, step_keys, warmup_keys = replicate_learner + # Warmup the buffer. + env_states, timesteps, keys, buffer_states = warmup( + env_states, timesteps, buffer_states, warmup_keys + ) + init_learner_state = MPOLearnerState( + params, opt_states, buffer_states, step_keys, env_states, timesteps + ) + + return learn, actor_network, init_learner_state + + +def run_experiment(_config: DictConfig) -> None: + """Runs experiment.""" + config = copy.deepcopy(_config) + + # Calculate total timesteps. + n_devices = len(jax.devices()) + config = check_total_timesteps(config) + assert ( + config.system.num_updates > config.arch.num_evaluation + ), "Number of updates per evaluation must be less than total number of updates." + + # Create the enviroments for train and eval. + env, eval_env = environments.make(config=config) + + # PRNG keys. + key, key_e, actor_net_key, q_net_key = jax.random.split( + jax.random.PRNGKey(config["system"]["seed"]), num=4 + ) + + # Setup learner. + learn, actor_network, learner_state = learner_setup( + env, (key, actor_net_key, q_net_key), config + ) + + # Setup evaluator. + evaluator, absolute_metric_evaluator, (trained_params, eval_keys) = evaluator_setup( + eval_env=eval_env, + key_e=key_e, + network=actor_network, + params=learner_state.params.actor_params.online, + config=config, + ) + + # Calculate number of updates per evaluation. + config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation + steps_per_rollout = ( + n_devices + * config.system.num_updates_per_eval + * config.system.rollout_length + * config.system.update_batch_size + * config.arch.num_envs + ) + + # Logger setup + logger = StoixLogger(config) + cfg: Dict = OmegaConf.to_container(config, resolve=True) + cfg["arch"]["devices"] = jax.devices() + pprint(cfg) + + # Set up checkpointer + save_checkpoint = config.logger.checkpointing.save_model + if save_checkpoint: + checkpointer = Checkpointer( + metadata=config, # Save all config as metadata in the checkpoint + model_name=config.logger.system_name, + **config.logger.checkpointing.save_args, # Checkpoint args + ) + + # Run experiment for a total number of evaluations. + max_episode_return = jnp.float32(-1e6) + best_params = unreplicate_batch_dim(learner_state.params.actor_params.online) + for eval_step in range(config.arch.num_evaluation): + # Train. + start_time = time.time() + + learner_output = learn(learner_state) + jax.block_until_ready(learner_output) + + # Log the results of the training. 
+ elapsed_time = time.time() - start_time + t = int(steps_per_rollout * (eval_step + 1)) + learner_output.episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + + # Separately log timesteps, actoring metrics and training metrics. + logger.log({"timestep": t}, t, eval_step, LogEvent.MISC) + logger.log(learner_output.episode_metrics, t, eval_step, LogEvent.ACT) + logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) + + # Prepare for evaluation. + start_time = time.time() + trained_params = unreplicate_batch_dim( + learner_output.learner_state.params.actor_params.online + ) # Select only actor params + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) + + # Evaluate. + evaluator_output = evaluator(trained_params, eval_keys) + jax.block_until_ready(evaluator_output) + + # Log the results of the evaluation. + elapsed_time = time.time() - start_time + episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"]) + + evaluator_output.episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL) + + if save_checkpoint: + checkpointer.save( + timestep=int(steps_per_rollout * (eval_step + 1)), + unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state), + episode_return=episode_return, + ) + + if config.arch.absolute_metric and max_episode_return <= episode_return: + best_params = copy.deepcopy(trained_params) + max_episode_return = episode_return + + # Update runner state to continue training. + learner_state = learner_output.learner_state + + # Measure absolute metric. + if config.arch.absolute_metric: + start_time = time.time() + + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) + + evaluator_output = absolute_metric_evaluator(best_params, eval_keys) + jax.block_until_ready(evaluator_output) + + elapsed_time = time.time() - start_time + t = int(steps_per_rollout * (eval_step + 1)) + evaluator_output.episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.ABSOLUTE) + + # Stop the logger. + logger.stop() + + +@hydra.main(config_path="../../configs", config_name="default_ff_mpo.yaml", version_base="1.2") +def hydra_entry_point(cfg: DictConfig) -> None: + """Experiment entry point.""" + # Allow dynamic attributes. + OmegaConf.set_struct(cfg, False) + + # Run experiment. 
+ run_experiment(cfg) + + print(f"{Fore.CYAN}{Style.BRIGHT}MPO experiment completed{Style.RESET_ALL}") + + +if __name__ == "__main__": + hydra_entry_point() diff --git a/stoix/systems/mpo/ff_mpo_continuous.py b/stoix/systems/mpo/ff_mpo_continuous.py new file mode 100644 index 00000000..426c050e --- /dev/null +++ b/stoix/systems/mpo/ff_mpo_continuous.py @@ -0,0 +1,737 @@ +import copy +import time +from typing import Any, Callable, Dict, Tuple + +import chex +import flashbax as fbx +import flax +import hydra +import jax +import jax.numpy as jnp +import optax +import rlax +from colorama import Fore, Style +from flashbax.buffers.trajectory_buffer import BufferState +from flax.core.frozen_dict import FrozenDict +from jumanji.env import Environment +from jumanji.types import TimeStep +from omegaconf import DictConfig, OmegaConf +from rich.pretty import pprint + +from stoix.evaluator import evaluator_setup +from stoix.networks.base import CompositeNetwork +from stoix.networks.base import FeedForwardActor as Actor +from stoix.systems.mpo.continuous_loss import clip_dual_params, mpo_loss +from stoix.systems.mpo.types import ( + ActorAndTarget, + DualParams, + MPOLearnerState, + MPOOptStates, + MPOParams, + SequenceStep, +) +from stoix.systems.q_learning.types import QsAndTarget +from stoix.systems.sac.types import ContinuousQApply +from stoix.types import ActorApply, ExperimentOutput, LearnerFn, LogEnvState +from stoix.utils import make_env as environments +from stoix.utils.checkpointing import Checkpointer +from stoix.utils.jax import unreplicate_batch_dim, unreplicate_n_dims +from stoix.utils.logger import LogEvent, StoixLogger +from stoix.utils.multistep import ( + batch_n_step_bootstrapped_returns, + batch_retrace_continuous, +) +from stoix.utils.total_timestep_checker import check_total_timesteps +from stoix.utils.training import make_learning_rate + + +def get_warmup_fn( + env: Environment, + params: MPOParams, + actor_apply_fn: ActorApply, + buffer_add_fn: Callable, + config: DictConfig, +) -> Callable: + def warmup( + env_states: LogEnvState, timesteps: TimeStep, buffer_states: BufferState, keys: chex.PRNGKey + ) -> Tuple[LogEnvState, TimeStep, BufferState, chex.PRNGKey]: + def _env_step( + carry: Tuple[LogEnvState, TimeStep, chex.PRNGKey], _: Any + ) -> Tuple[Tuple[LogEnvState, TimeStep, chex.PRNGKey], SequenceStep]: + """Step the environment.""" + + env_state, last_timestep, key = carry + # SELECT ACTION + key, policy_key = jax.random.split(key) + actor_policy = actor_apply_fn(params.actor_params.online, last_timestep.observation) + action = actor_policy.sample(seed=policy_key) + log_prob = actor_policy.log_prob(action) + + # STEP ENVIRONMENT + env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) + + # LOG EPISODE METRICS + done = timestep.last().reshape(-1) + info = timestep.extras["episode_metrics"] + + sequence_step = SequenceStep( + last_timestep.observation, action, timestep.reward, done, log_prob, info + ) + + return (env_state, timestep, key), sequence_step + + # STEP ENVIRONMENT FOR ROLLOUT LENGTH + (env_states, timesteps, keys), traj_batch = jax.lax.scan( + _env_step, (env_states, timesteps, keys), None, config.system.warmup_steps + ) + + # Add the trajectory to the buffer. + # Swap the batch and time axes. 
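# (Shape note, not part of this diff.) jax.lax.scan stacks per-step outputs on the
# leading axis, so the collected trajectory is time-major, [T, B, ...]; the flashbax
# trajectory buffer is created with add_batch_size leading, i.e. [B, T, ...], hence
# the swap below. For example, a reward leaf of shape (warmup_steps, num_envs)
# becomes (num_envs, warmup_steps):
#     _ex_rewards = jnp.zeros((8, 16))               # [T=8, B=16]
#     _ex_rewards = jnp.swapaxes(_ex_rewards, 0, 1)  # [B=16, T=8]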
+ traj_batch = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), traj_batch) + buffer_states = buffer_add_fn(buffer_states, traj_batch) + + return env_states, timesteps, keys, buffer_states + + batched_warmup_step: Callable = jax.vmap( + warmup, in_axes=(0, 0, 0, 0), out_axes=(0, 0, 0, 0), axis_name="batch" + ) + + return batched_warmup_step + + +def get_learner_fn( + env: Environment, + apply_fns: Tuple[ActorApply, ContinuousQApply], + update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn, optax.TransformUpdateFn], + buffer_fns: Tuple[Callable, Callable], + config: DictConfig, +) -> LearnerFn[MPOLearnerState]: + """Get the learner function.""" + + # Get apply and update functions for actor and critic networks. + actor_apply_fn, q_apply_fn = apply_fns + actor_update_fn, q_update_fn, dual_update_fn = update_fns + buffer_add_fn, buffer_sample_fn = buffer_fns + + def _update_step(learner_state: MPOLearnerState, _: Any) -> Tuple[MPOLearnerState, Tuple]: + def _env_step( + learner_state: MPOLearnerState, _: Any + ) -> Tuple[MPOLearnerState, SequenceStep]: + """Step the environment.""" + params, opt_states, buffer_state, key, env_state, last_timestep = learner_state + + # SELECT ACTION + key, policy_key = jax.random.split(key) + actor_policy = actor_apply_fn(params.actor_params.online, last_timestep.observation) + action = actor_policy.sample(seed=policy_key) + log_prob = actor_policy.log_prob(action) + + # STEP ENVIRONMENT + env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) + + # LOG EPISODE METRICS + done = timestep.last().reshape(-1) + info = timestep.extras["episode_metrics"] + + sequence_step = SequenceStep( + last_timestep.observation, action, timestep.reward, done, log_prob, info + ) + + learner_state = MPOLearnerState( + params, opt_states, buffer_state, key, env_state, timestep + ) + return learner_state, sequence_step + + # STEP ENVIRONMENT FOR ROLLOUT LENGTH + learner_state, traj_batch = jax.lax.scan( + _env_step, learner_state, None, config.system.rollout_length + ) + + params, opt_states, buffer_state, key, env_state, last_timestep = learner_state + + # Add the trajectory to the buffer. + # Swap the batch and time axes. + traj_batch = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), traj_batch) + buffer_state = buffer_add_fn(buffer_state, traj_batch) + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def _actor_loss_fn( + online_actor_params: FrozenDict, + dual_params: DualParams, + target_actor_params: FrozenDict, + target_q_params: FrozenDict, + sequence: SequenceStep, + key: chex.PRNGKey, + ) -> chex.Array: + # Reshape the observations to [B*T, ...]. + reshaped_obs = jax.tree_map( + lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), sequence.obs + ) + + online_actor_policy = actor_apply_fn(online_actor_params, reshaped_obs) + target_actor_policy = actor_apply_fn(target_actor_params, reshaped_obs) + target_sampled_actions = target_actor_policy.sample( + seed=key, sample_shape=config.system.num_samples + ) + target_sampled_q_values = jax.vmap(q_apply_fn, in_axes=(None, None, 0))( + target_q_params, reshaped_obs, target_sampled_actions + ) + + # Compute the policy and dual loss. 
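+                # MPO policy improvement (Abdolmaleki et al., 2018): the E-step
+                # re-weights the sampled actions with weights proportional to
+                # exp(Q(s, a) / temperature), where the temperature dual keeps the
+                # induced KL within config.system.epsilon; the M-step then fits the
+                # parametric policy to those weights under decoupled KL bounds on the
+                # mean and stddev (epsilon_mean / epsilon_stddev), with the dual
+                # variables trained by the same loss.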
+                loss, loss_info = mpo_loss(
+                    dual_params=dual_params,
+                    online_action_distribution=online_actor_policy,
+                    target_action_distribution=target_actor_policy,
+                    target_sampled_actions=target_sampled_actions,
+                    target_sampled_q_values=target_sampled_q_values,
+                    epsilon=config.system.epsilon,
+                    epsilon_mean=config.system.epsilon_mean,
+                    epsilon_stddev=config.system.epsilon_stddev,
+                    per_dim_constraining=config.system.per_dim_constraining,
+                    action_penalization=config.system.action_penalization,
+                    epsilon_penalty=config.system.epsilon_penalty,
+                )
+
+                return jnp.mean(loss), loss_info
+
+            def _q_loss_fn(
+                online_q_params: FrozenDict,
+                target_q_params: FrozenDict,
+                online_actor_params: FrozenDict,
+                target_actor_params: FrozenDict,
+                sequences: SequenceStep,
+                rng_key: chex.PRNGKey,
+            ) -> jnp.ndarray:
+
+                online_actor_policy = actor_apply_fn(
+                    online_actor_params, sequences.obs
+                )  # [B, T, ...]
+                target_actor_policy = actor_apply_fn(
+                    target_actor_params, sequences.obs
+                )  # [B, T, ...]
+                online_q_t = q_apply_fn(online_q_params, sequences.obs, sequences.action)  # [B, T]
+
+                # Cast and clip rewards.
+                discount = 1.0 - sequences.done.astype(jnp.float32)
+                d_t = (discount * config.system.gamma).astype(jnp.float32)
+                r_t = jnp.clip(
+                    sequences.reward, -config.system.max_abs_reward, config.system.max_abs_reward
+                ).astype(jnp.float32)
+
+                # Policy to use for policy evaluation and bootstrapping.
+                if config.system.use_online_policy_to_bootstrap:
+                    policy_to_evaluate = online_actor_policy
+                else:
+                    policy_to_evaluate = target_actor_policy
+
+                # Action(s) to use for policy evaluation; shape [N, B, T].
+                if config.system.stochastic_policy_eval:
+                    a_evaluation = policy_to_evaluate.sample(
+                        seed=rng_key, sample_shape=config.system.num_samples
+                    )  # [N, B, T, ...]
+                else:
+                    a_evaluation = policy_to_evaluate.mode()[jnp.newaxis, ...]  # [N=1, B, T, ...]
+
+                # Add a stopgrad in case we use the online policy for evaluation.
+                a_evaluation = jax.lax.stop_gradient(a_evaluation)
+
+                # Compute the Q-values for the next state-action pairs; [N, B, T].
+                q_values = jax.vmap(q_apply_fn, in_axes=(None, None, 0))(
+                    target_q_params, sequences.obs, a_evaluation
+                )
+
+                # When policy_eval_stochastic == True, this corresponds to expected SARSA.
+                # Otherwise, the mean is a no-op.
+                v_t = jnp.mean(q_values, axis=0)  # [B, T]
+
+                if config.system.use_retrace:
+                    # Compute the log-rhos for the retrace targets.
+                    log_rhos = target_actor_policy.log_prob(sequences.action) - sequences.log_prob
+
+                    # Compute target Q-values.
+                    target_q_t = q_apply_fn(
+                        target_q_params, sequences.obs, sequences.action
+                    )  # [B, T]
+
+                    # Compute retrace targets.
+                    # These targets use the rewards and discounts as in normal TD-learning but
+                    # they use a mix of bootstrapped values V(s') and Q(s', a'), weighing the
+                    # latter based on how likely a' is under the current policy (s' and a' are
+                    # samples from replay).
+                    # See [Munos et al., 2016](https://arxiv.org/abs/1606.02647) for more.
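+                    # Inside the retrace computation the correction weights are the
+                    # truncated importance ratios
+                    # c_t = retrace_lambda * min(1, exp(log_rho_t)),
+                    # i.e. lambda-scaled ratios of target to behaviour action probabilities.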
+ retrace_error = batch_retrace_continuous( + online_q_t[:, :-1], + target_q_t[:, 1:-1], + v_t[:, 1:], + r_t[:, :-1], + d_t[:, :-1], + log_rhos[:, 1:-1], + config.system.retrace_lambda, + ) + q_loss = rlax.l2_loss(retrace_error).mean() + else: + n_step_value_target = batch_n_step_bootstrapped_returns( + r_t[:, :-1], + d_t[:, :-1], + v_t[:, 1:], + config.system.n_step_for_sequence_bootstrap, + ) + td_error = online_q_t[:, :-1] - n_step_value_target + q_loss = rlax.l2_loss(td_error).mean() + + loss_info = { + "q_loss": q_loss, + } + + return q_loss, loss_info + + params, opt_states, buffer_state, key = update_state + + key, sample_key, actor_key, q_key = jax.random.split(key, num=4) + + # SAMPLE SEQUENCES + sequence_sample = buffer_sample_fn(buffer_state, sample_key) + sequence: SequenceStep = sequence_sample.experience + + # CALCULATE ACTOR AND DUAL LOSS + actor_dual_grad_fn = jax.grad(_actor_loss_fn, argnums=(0, 1), has_aux=True) + actor_dual_grads, actor_loss_info = actor_dual_grad_fn( + params.actor_params.online, + params.dual_params, + params.actor_params.target, + params.q_params.target, + sequence, + actor_key, + ) + + # CALCULATE Q LOSS + q_grad_fn = jax.grad(_q_loss_fn, has_aux=True) + q_grads, q_loss_info = q_grad_fn( + params.q_params.online, + params.q_params.target, + params.actor_params.online, + params.actor_params.target, + sequence, + q_key, + ) + + # Compute the parallel mean (pmean) over the batch. + # This calculation is inspired by the Anakin architecture demo notebook. + # available at https://tinyurl.com/26tdzs5x + # This pmean could be a regular mean as the batch axis is on the same device. + actor_dual_grads, actor_loss_info = jax.lax.pmean( + (actor_dual_grads, actor_loss_info), axis_name="batch" + ) + # pmean over devices. + actor_dual_grads, actor_loss_info = jax.lax.pmean( + (actor_dual_grads, actor_loss_info), axis_name="device" + ) + + q_grads, q_loss_info = jax.lax.pmean((q_grads, q_loss_info), axis_name="batch") + q_grads, q_loss_info = jax.lax.pmean((q_grads, q_loss_info), axis_name="device") + + actor_grads, dual_grads = actor_dual_grads + + # UPDATE ACTOR PARAMS AND OPTIMISER STATE + actor_updates, actor_new_opt_state = actor_update_fn( + actor_grads, opt_states.actor_opt_state + ) + actor_new_online_params = optax.apply_updates(params.actor_params.online, actor_updates) + + # UPDATE DUAL PARAMS AND OPTIMISER STATE + dual_updates, dual_new_opt_state = dual_update_fn(dual_grads, opt_states.dual_opt_state) + dual_new_params = optax.apply_updates(params.dual_params, dual_updates) + dual_new_params = clip_dual_params(dual_new_params, config.system.per_dim_constraining) + + # UPDATE Q PARAMS AND OPTIMISER STATE + q_updates, q_new_opt_state = q_update_fn(q_grads, opt_states.q_opt_state) + q_new_online_params = optax.apply_updates(params.q_params.online, q_updates) + # Target network polyak update. 
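+            # optax.incremental_update applies the soft (Polyak) update
+            # target <- (1 - tau) * target + tau * online, with tau = config.system.tau.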
+ (new_target_actor_params, new_target_q_params) = optax.incremental_update( + (actor_new_online_params, q_new_online_params), + (params.actor_params.target, params.q_params.target), + config.system.tau, + ) + + actor_new_params = ActorAndTarget(actor_new_online_params, new_target_actor_params) + q_new_params = QsAndTarget(q_new_online_params, new_target_q_params) + + # PACK NEW PARAMS AND OPTIMISER STATE + new_params = MPOParams(actor_new_params, q_new_params, dual_new_params) + new_opt_state = MPOOptStates(actor_new_opt_state, q_new_opt_state, dual_new_opt_state) + + # PACK LOSS INFO + loss_info = actor_loss_info._asdict() + loss_info = { + **loss_info, + "value_loss": q_loss_info["q_loss"], + } + return (new_params, new_opt_state, buffer_state, key), loss_info + + update_state = (params, opt_states, buffer_state, key) + + # UPDATE EPOCHS + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config.system.epochs + ) + + params, opt_states, buffer_state, key = update_state + learner_state = MPOLearnerState( + params, opt_states, buffer_state, key, env_state, last_timestep + ) + metric = traj_batch.info + return learner_state, (metric, loss_info) + + def learner_fn(learner_state: MPOLearnerState) -> ExperimentOutput[MPOLearnerState]: + """Learner function. + + This function represents the learner, it updates the network parameters + by iteratively applying the `_update_step` function for a fixed number of + updates. The `_update_step` function is vectorized over a batch of inputs. + """ + + batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") + + learner_state, (episode_info, loss_info) = jax.lax.scan( + batched_update_step, learner_state, None, config.system.num_updates_per_eval + ) + return ExperimentOutput( + learner_state=learner_state, + episode_metrics=episode_info, + train_metrics=loss_info, + ) + + return learner_fn + + +def learner_setup( + env: Environment, keys: chex.Array, config: DictConfig +) -> Tuple[LearnerFn[MPOLearnerState], Actor, MPOLearnerState]: + """Initialise learner_fn, network, optimiser, environment and states.""" + # Get available TPU cores. + n_devices = len(jax.devices()) + + # Get number of actions or action dimension from the environment. + action_dim = int(env.action_spec().shape[-1]) + config.system.action_dim = action_dim + + # PRNG keys. + key, actor_net_key, q_net_key = keys + + # Define actor_network, q_network and optimiser. + actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) + actor_action_head = hydra.utils.instantiate( + config.network.actor_network.action_head, action_dim=action_dim + ) + actor_network = Actor(torso=actor_torso, action_head=actor_action_head) + q_network_input = hydra.utils.instantiate(config.network.q_network.input_layer) + q_network_torso = hydra.utils.instantiate(config.network.q_network.pre_torso) + q_network_head = hydra.utils.instantiate(config.network.q_network.critic_head) + q_network = CompositeNetwork([q_network_input, q_network_torso, q_network_head]) + + actor_lr = make_learning_rate(config.system.actor_lr, config, config.system.epochs) + q_lr = make_learning_rate(config.system.q_lr, config, config.system.epochs) + + actor_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(actor_lr, eps=1e-5), + ) + q_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(q_lr, eps=1e-5), + ) + + # Initialise observation: Select only obs for a single agent. 
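+    # A single dummy observation (with a leading batch axis of one) and a zero
+    # action of shape [1, action_dim] are used only to trace the networks and
+    # create their parameter pytrees.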
+ init_x = env.observation_spec().generate_value() + init_x = jax.tree_util.tree_map(lambda x: x[None, ...], init_x) + init_a = jnp.zeros((1, action_dim)) + + # Initialise actor params and optimiser state. + actor_params = actor_network.init(actor_net_key, init_x) + target_actor_params = actor_params + actor_opt_state = actor_optim.init(actor_params) + + # Initialise q params and optimiser state. + online_q_params = q_network.init(q_net_key, init_x, init_a) + target_q_params = online_q_params + q_opt_state = q_optim.init(online_q_params) + + # Initialise MPO Dual params and optimiser state. + if config.system.per_dim_constraining: + dual_variable_shape = [action_dim] + else: + dual_variable_shape = [1] + + log_temperature = jnp.full([1], config.system.init_log_temperature, dtype=jnp.float32) + + log_alpha_mean = jnp.full( + dual_variable_shape, config.system.init_log_alpha_mean, dtype=jnp.float32 + ) + + log_alpha_stddev = jnp.full( + dual_variable_shape, config.system.init_log_alpha_stddev, dtype=jnp.float32 + ) + + if config.system.action_penalization: + log_penalty_temperature = jnp.full( + [1], config.system.init_log_temperature, dtype=jnp.float32 + ) + else: + log_penalty_temperature = None + + dual_params = DualParams( + log_temperature=log_temperature, + log_alpha_mean=log_alpha_mean, + log_alpha_stddev=log_alpha_stddev, + log_penalty_temperature=log_penalty_temperature, + ) + + dual_lr = make_learning_rate(config.system.dual_lr, config, config.system.epochs) + dual_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(dual_lr, eps=1e-5), + ) + dual_opt_state = dual_optim.init(dual_params) + + params = MPOParams( + ActorAndTarget(actor_params, target_actor_params), + QsAndTarget(online_q_params, target_q_params), + dual_params, + ) + opt_states = MPOOptStates(actor_opt_state, q_opt_state, dual_opt_state) + + actor_network_apply_fn = actor_network.apply + q_network_apply_fn = q_network.apply + + # Pack apply and update functions. + apply_fns = (actor_network_apply_fn, q_network_apply_fn) + update_fns = (actor_optim.update, q_optim.update, dual_optim.update) + + # Create replay buffer + dummy_sequence_step = SequenceStep( + obs=jax.tree_util.tree_map(lambda x: x.squeeze(0), init_x), + action=jnp.zeros((action_dim), dtype=float), + reward=jnp.zeros((), dtype=float), + done=jnp.zeros((), dtype=bool), + log_prob=jnp.zeros((), dtype=float), + info={"episode_return": 0.0, "episode_length": 0}, + ) + + buffer_fn = fbx.make_trajectory_buffer( + max_size=config.system.buffer_size, + min_length_time_axis=config.system.sample_sequence_length, + sample_batch_size=config.system.batch_size, + sample_sequence_length=config.system.sample_sequence_length, + period=config.system.period, + add_batch_size=config.arch.num_envs, + ) + buffer_fns = (buffer_fn.add, buffer_fn.sample) + buffer_states = buffer_fn.init(dummy_sequence_step) + + # Get batched iterated update and replicate it to pmap it over cores. + learn = get_learner_fn(env, apply_fns, update_fns, buffer_fns, config) + learn = jax.pmap(learn, axis_name="device") + + warmup = get_warmup_fn(env, params, actor_network_apply_fn, buffer_fn.add, config) + warmup = jax.pmap(warmup, axis_name="device") + + # Initialise environment states and timesteps: across devices and batches. 
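+    # One reset key is drawn per environment instance, i.e.
+    # n_devices * update_batch_size * num_envs keys in total, and the resulting
+    # states are reshaped to (devices, update batch size, num_envs, ...).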
+ key, *env_keys = jax.random.split( + key, n_devices * config.system.update_batch_size * config.arch.num_envs + 1 + ) + env_states, timesteps = jax.vmap(env.reset, in_axes=(0))( + jnp.stack(env_keys), + ) + reshape_states = lambda x: x.reshape( + (n_devices, config.system.update_batch_size, config.arch.num_envs) + x.shape[1:] + ) + # (devices, update batch size, num_envs, ...) + env_states = jax.tree_map(reshape_states, env_states) + timesteps = jax.tree_map(reshape_states, timesteps) + + # Load model from checkpoint if specified. + if config.logger.checkpointing.load_model: + loaded_checkpoint = Checkpointer( + model_name=config.logger.system_name, + **config.logger.checkpointing.load_args, # Other checkpoint args + ) + # Restore the learner state from the checkpoint + restored_params, _ = loaded_checkpoint.restore_params(TParams=MPOParams) + # Update the params + params = restored_params + + # Define params to be replicated across devices and batches. + key, step_keys, warmup_keys = jax.random.split(key, num=3) + + replicate_learner = (params, opt_states, buffer_states, step_keys, warmup_keys) + + # Duplicate learner for update_batch_size. + broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size,) + x.shape) + replicate_learner = jax.tree_map(broadcast, replicate_learner) + + # Duplicate learner across devices. + replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices()) + + # Initialise learner state. + params, opt_states, buffer_states, step_keys, warmup_keys = replicate_learner + # Warmup the buffer. + env_states, timesteps, keys, buffer_states = warmup( + env_states, timesteps, buffer_states, warmup_keys + ) + init_learner_state = MPOLearnerState( + params, opt_states, buffer_states, step_keys, env_states, timesteps + ) + + return learn, actor_network, init_learner_state + + +def run_experiment(_config: DictConfig) -> None: + """Runs experiment.""" + config = copy.deepcopy(_config) + + # Calculate total timesteps. + n_devices = len(jax.devices()) + config = check_total_timesteps(config) + assert ( + config.system.num_updates > config.arch.num_evaluation + ), "Number of updates per evaluation must be less than total number of updates." + + # Create the enviroments for train and eval. + env, eval_env = environments.make(config=config) + + # PRNG keys. + key, key_e, actor_net_key, q_net_key = jax.random.split( + jax.random.PRNGKey(config["system"]["seed"]), num=4 + ) + + # Setup learner. + learn, actor_network, learner_state = learner_setup( + env, (key, actor_net_key, q_net_key), config + ) + + # Setup evaluator. + evaluator, absolute_metric_evaluator, (trained_params, eval_keys) = evaluator_setup( + eval_env=eval_env, + key_e=key_e, + network=actor_network, + params=learner_state.params.actor_params.online, + config=config, + ) + + # Calculate number of updates per evaluation. 
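+    # Environment steps executed between two evaluations: rollouts of length
+    # rollout_length run in num_envs environments, for num_updates_per_eval updates,
+    # across update_batch_size batches on each of the n_devices devices.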
+ config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation + steps_per_rollout = ( + n_devices + * config.system.num_updates_per_eval + * config.system.rollout_length + * config.system.update_batch_size + * config.arch.num_envs + ) + + # Logger setup + logger = StoixLogger(config) + cfg: Dict = OmegaConf.to_container(config, resolve=True) + cfg["arch"]["devices"] = jax.devices() + pprint(cfg) + + # Set up checkpointer + save_checkpoint = config.logger.checkpointing.save_model + if save_checkpoint: + checkpointer = Checkpointer( + metadata=config, # Save all config as metadata in the checkpoint + model_name=config.logger.system_name, + **config.logger.checkpointing.save_args, # Checkpoint args + ) + + # Run experiment for a total number of evaluations. + max_episode_return = jnp.float32(-1e6) + best_params = unreplicate_batch_dim(learner_state.params.actor_params.online) + for eval_step in range(config.arch.num_evaluation): + # Train. + start_time = time.time() + + learner_output = learn(learner_state) + jax.block_until_ready(learner_output) + + # Log the results of the training. + elapsed_time = time.time() - start_time + t = int(steps_per_rollout * (eval_step + 1)) + learner_output.episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + + # Separately log timesteps, actoring metrics and training metrics. + logger.log({"timestep": t}, t, eval_step, LogEvent.MISC) + logger.log(learner_output.episode_metrics, t, eval_step, LogEvent.ACT) + logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) + + # Prepare for evaluation. + start_time = time.time() + trained_params = unreplicate_batch_dim( + learner_output.learner_state.params.actor_params.online + ) # Select only actor params + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) + + # Evaluate. + evaluator_output = evaluator(trained_params, eval_keys) + jax.block_until_ready(evaluator_output) + + # Log the results of the evaluation. + elapsed_time = time.time() - start_time + episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"]) + + evaluator_output.episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL) + + if save_checkpoint: + checkpointer.save( + timestep=int(steps_per_rollout * (eval_step + 1)), + unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state), + episode_return=episode_return, + ) + + if config.arch.absolute_metric and max_episode_return <= episode_return: + best_params = copy.deepcopy(trained_params) + max_episode_return = episode_return + + # Update runner state to continue training. + learner_state = learner_output.learner_state + + # Measure absolute metric. + if config.arch.absolute_metric: + start_time = time.time() + + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) + + evaluator_output = absolute_metric_evaluator(best_params, eval_keys) + jax.block_until_ready(evaluator_output) + + elapsed_time = time.time() - start_time + t = int(steps_per_rollout * (eval_step + 1)) + evaluator_output.episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.ABSOLUTE) + + # Stop the logger. 
+ logger.stop() + + +@hydra.main( + config_path="../../configs", config_name="default_ff_mpo_continuous.yaml", version_base="1.2" +) +def hydra_entry_point(cfg: DictConfig) -> None: + """Experiment entry point.""" + # Allow dynamic attributes. + OmegaConf.set_struct(cfg, False) + + # Run experiment. + run_experiment(cfg) + + print(f"{Fore.CYAN}{Style.BRIGHT}MPO experiment completed{Style.RESET_ALL}") + + +if __name__ == "__main__": + hydra_entry_point() diff --git a/stoix/systems/mpo/types.py b/stoix/systems/mpo/types.py new file mode 100644 index 00000000..191b0839 --- /dev/null +++ b/stoix/systems/mpo/types.py @@ -0,0 +1,103 @@ +from typing import Dict, Optional, Union + +import chex +import optax +from flashbax.buffers.trajectory_buffer import BufferState +from flax.core.frozen_dict import FrozenDict +from jumanji.types import TimeStep +from typing_extensions import NamedTuple + +from stoix.systems.q_learning.types import QsAndTarget +from stoix.types import LogEnvState + + +class SequenceStep(NamedTuple): + obs: chex.ArrayTree + action: chex.Array + reward: chex.Array + done: chex.Array + log_prob: chex.Array + info: Dict + + +class ActorAndTarget(NamedTuple): + online: FrozenDict + target: FrozenDict + + +class DualParams(NamedTuple): + log_temperature: chex.Array + log_alpha_mean: chex.Array + log_alpha_stddev: chex.Array + log_penalty_temperature: Optional[chex.Array] = None + + +class CategoricalDualParams(NamedTuple): + log_temperature: chex.Array + log_alpha: chex.Array + + +class MPOParams(NamedTuple): + actor_params: FrozenDict + q_params: QsAndTarget + dual_params: Union[DualParams, CategoricalDualParams] + + +class MPOOptStates(NamedTuple): + actor_opt_state: optax.OptState + q_opt_state: optax.OptState + dual_opt_state: optax.OptState + + +class MPOLearnerState(NamedTuple): + params: MPOParams + opt_states: MPOOptStates + buffer_state: BufferState + key: chex.PRNGKey + env_state: LogEnvState + timestep: TimeStep + + +class MPOStats(NamedTuple): + dual_alpha_mean: Union[float, chex.Array] + dual_alpha_stddev: Union[float, chex.Array] + dual_temperature: Union[float, chex.Array] + + loss_policy: Union[float, chex.Array] + loss_alpha: Union[float, chex.Array] + loss_temperature: Union[float, chex.Array] + kl_q_rel: Union[float, chex.Array] + + kl_mean_rel: Union[float, chex.Array] + kl_stddev_rel: Union[float, chex.Array] + + q_min: Union[float, chex.Array] + q_max: Union[float, chex.Array] + + pi_stddev_min: Union[float, chex.Array] + pi_stddev_max: Union[float, chex.Array] + pi_stddev_cond: Union[float, chex.Array] + + penalty_kl_q_rel: Optional[float] = None + + +class CategoricalMPOStats(NamedTuple): + dual_alpha: float + dual_temperature: float + + loss_e_step: float + loss_m_step: float + loss_dual: float + + loss_policy: float + loss_alpha: float + loss_temperature: float + + kl_q_rel: float + kl_mean_rel: float + + q_min: float + q_max: float + + entropy_online: float + entropy_target: float diff --git a/stoix/systems/sac/types.py b/stoix/systems/sac/types.py index c8b77154..c9a44d6c 100644 --- a/stoix/systems/sac/types.py +++ b/stoix/systems/sac/types.py @@ -13,11 +13,6 @@ ContinuousQApply = Callable[[FrozenDict, Observation, Action], Value] -class Qs(NamedTuple): - q1: FrozenDict - q2: FrozenDict - - class SACParams(NamedTuple): actor_params: FrozenDict q_params: QsAndTarget diff --git a/stoix/utils/multistep.py b/stoix/utils/multistep.py index d2d5b995..ad120a18 100644 --- a/stoix/utils/multistep.py +++ b/stoix/utils/multistep.py @@ -1,9 +1,13 @@ -from typing import 
Tuple
+from typing import Tuple, Union

 import chex
 import jax
 import jax.numpy as jnp

+# These functions are generally taken from rlax but edited to explicitly take in a batch of data.
+# This is because the original rlax functions are not batched and are meant to be used with vmap,
+# which can be much slower.
+

 def calculate_gae(
     v_t: chex.Array,
@@ -35,3 +39,169 @@ def _get_advantages(
         unroll=16,
     )
     return advantages, advantages + v_t
+
+
+def batch_n_step_bootstrapped_returns(
+    r_t: chex.Array,
+    discount_t: chex.Array,
+    v_t: chex.Array,
+    n: int,
+    lambda_t: float = 1.0,
+    stop_target_gradients: bool = True,
+) -> chex.Array:
+    """Computes strided n-step bootstrapped return targets over a batch of sequences.
+
+    The returns are computed according to the below equation iterated `n` times:
+
+    Gₜ = rₜ₊₁ + γₜ₊₁ [(1 - λₜ₊₁) vₜ₊₁ + λₜ₊₁ Gₜ₊₁].
+
+    When lambda_t == 1. (default), this reduces to
+
+    Gₜ = rₜ₊₁ + γₜ₊₁ * (rₜ₊₂ + γₜ₊₂ * (... * (rₜ₊ₙ + γₜ₊ₙ * vₜ₊ₙ ))).
+
+    Args:
+      r_t: rewards at times B x [1, ..., T].
+      discount_t: discounts at times B x [1, ..., T].
+      v_t: state or state-action values to bootstrap from at times B x [1, ..., T].
+      n: number of steps over which to accumulate reward before bootstrapping.
+      lambda_t: lambdas at times B x [1, ..., T]. Shape is [], or B x [T-1].
+      stop_target_gradients: bool indicating whether or not to apply stop gradient
+        to targets.
+
+    Returns:
+      estimated bootstrapped returns at times B x [0, ..., T-1]
+    """
+    # swap axes to make time axis the first dimension
+    r_t, discount_t, v_t = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), (r_t, discount_t, v_t))
+    seq_len = r_t.shape[0]
+    batch_size = r_t.shape[1]
+
+    # Maybe change scalar lambda to an array.
+    lambda_t = jnp.ones_like(discount_t) * lambda_t
+
+    # Shift bootstrap values by n and pad end of sequence with last value v_t[-1].
+    pad_size = min(n - 1, seq_len)
+    targets = jnp.concatenate([v_t[n - 1 :], jnp.array([v_t[-1]] * pad_size)], axis=0)
+
+    # Pad sequences. Shape is now (T + n - 1,).
+    r_t = jnp.concatenate([r_t, jnp.zeros((n - 1, batch_size))], axis=0)
+    discount_t = jnp.concatenate([discount_t, jnp.ones((n - 1, batch_size))], axis=0)
+    lambda_t = jnp.concatenate([lambda_t, jnp.ones((n - 1, batch_size))], axis=0)
+    v_t = jnp.concatenate([v_t, jnp.array([v_t[-1]] * (n - 1))], axis=0)
+
+    # Work backwards to compute n-step returns.
+    for i in reversed(range(n)):
+        r_ = r_t[i : i + seq_len]
+        discount_ = discount_t[i : i + seq_len]
+        lambda_ = lambda_t[i : i + seq_len]
+        v_ = v_t[i : i + seq_len]
+        targets = r_ + discount_ * ((1.0 - lambda_) * v_ + lambda_ * targets)
+
+    targets = jnp.swapaxes(targets, 0, 1)
+    return jax.lax.select(stop_target_gradients, jax.lax.stop_gradient(targets), targets)
+
+
+def batch_general_off_policy_returns_from_q_and_v(
+    q_t: chex.Array,
+    v_t: chex.Array,
+    r_t: chex.Array,
+    discount_t: chex.Array,
+    c_t: chex.Array,
+    stop_target_gradients: bool = False,
+) -> chex.Array:
+    """Calculates targets for various off-policy evaluation algorithms.
+ + Given a window of experience of length `K+1`, generated by a behaviour policy + μ, for each time-step `t` we can estimate the return `G_t` from that step + onwards, under some target policy π, using the rewards in the trajectory, the + values under π of states and actions selected by μ, according to equation: + + Gₜ = rₜ₊₁ + γₜ₊₁ * (vₜ₊₁ - cₜ₊₁ * q(aₜ₊₁) + cₜ₊₁* Gₜ₊₁), + + where, depending on the choice of `c_t`, the algorithm implements: + + Importance Sampling c_t = π(x_t, a_t) / μ(x_t, a_t), + Harutyunyan's et al. Q(lambda) c_t = λ, + Precup's et al. Tree-Backup c_t = π(x_t, a_t), + Munos' et al. Retrace c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)). + + See "Safe and Efficient Off-Policy Reinforcement Learning" by Munos et al. + (https://arxiv.org/abs/1606.02647). + + Args: + q_t: Q-values under π of actions executed by μ at times [1, ..., K - 1]. + v_t: Values under π at times [1, ..., K]. + r_t: rewards at times [1, ..., K]. + discount_t: discounts at times [1, ..., K]. + c_t: weights at times [1, ..., K - 1]. + stop_target_gradients: bool indicating whether or not to apply stop gradient + to targets. + + Returns: + Off-policy estimates of the generalized returns from states visited at times + [0, ..., K - 1]. + """ + q_t, v_t, r_t, discount_t, c_t = jax.tree_map( + lambda x: jnp.swapaxes(x, 0, 1), (q_t, v_t, r_t, discount_t, c_t) + ) + + g = r_t[-1] + discount_t[-1] * v_t[-1] # G_K-1. + + def _body( + acc: chex.Array, xs: Tuple[chex.Array, chex.Array, chex.Array, chex.Array, chex.Array] + ) -> Tuple[chex.Array, chex.Array]: + reward, discount, c, v, q = xs + acc = reward + discount * (v - c * q + c * acc) + return acc, acc + + _, returns = jax.lax.scan( + _body, g, (r_t[:-1], discount_t[:-1], c_t, v_t[:-1], q_t), reverse=True + ) + returns = jnp.concatenate([returns, g[jnp.newaxis]], axis=0) + + returns = jnp.swapaxes(returns, 0, 1) + return jax.lax.select(stop_target_gradients, jax.lax.stop_gradient(returns), returns) + + +def batch_retrace_continuous( + q_tm1: chex.Array, + q_t: chex.Array, + v_t: chex.Array, + r_t: chex.Array, + discount_t: chex.Array, + log_rhos: chex.Array, + lambda_: Union[chex.Array, float], + stop_target_gradients: bool = True, +) -> chex.Array: + """Retrace continuous. + + See "Safe and Efficient Off-Policy Reinforcement Learning" by Munos et al. + (https://arxiv.org/abs/1606.02647). + + Args: + q_tm1: Q-values at times [0, ..., K - 1]. + q_t: Q-values evaluated at actions collected using behavior + policy at times [1, ..., K - 1]. + v_t: Value estimates of the target policy at times [1, ..., K]. + r_t: reward at times [1, ..., K]. + discount_t: discount at times [1, ..., K]. + log_rhos: Log importance weight pi_target/pi_behavior evaluated at actions + collected using behavior policy [1, ..., K - 1]. + lambda_: scalar or a vector of mixing parameter lambda. + stop_target_gradients: bool indicating whether or not to apply stop gradient + to targets. + + Returns: + Retrace error. + """ + + c_t = jnp.minimum(1.0, jnp.exp(log_rhos)) * lambda_ + + # The generalized returns are independent of Q-values and cs at the final + # state. 
+ target_tm1 = batch_general_off_policy_returns_from_q_and_v(q_t, v_t, r_t, discount_t, c_t) + + target_tm1 = jax.lax.select( + stop_target_gradients, jax.lax.stop_gradient(target_tm1), target_tm1 + ) + return target_tm1 - q_tm1 diff --git a/stoix/wrappers/jumanji.py b/stoix/wrappers/jumanji.py index ff177a4d..5696e277 100644 --- a/stoix/wrappers/jumanji.py +++ b/stoix/wrappers/jumanji.py @@ -36,7 +36,7 @@ def __init__( def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: state, timestep = self._env.reset(key) - obs = timestep.observation._asdict()[self._observation_attribute] + obs = timestep.observation._asdict()[self._observation_attribute].astype(jnp.float32) timestep = timestep.replace( observation=Observation( obs.reshape(self._obs_shape), self._legal_action_mask, state.step_count @@ -47,7 +47,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: state, timestep = self._env.step(state, action) - obs = timestep.observation._asdict()[self._observation_attribute] + obs = timestep.observation._asdict()[self._observation_attribute].astype(jnp.float32) timestep = timestep.replace( observation=Observation( obs.reshape(self._obs_shape), self._legal_action_mask, state.step_count diff --git a/stoix/wrappers/truncation.py b/stoix/wrappers/truncation.py index 744ed895..4da99ce0 100644 --- a/stoix/wrappers/truncation.py +++ b/stoix/wrappers/truncation.py @@ -42,9 +42,9 @@ def _auto_reset(self, state: State, timestep: TimeStep) -> Tuple[State, TimeStep extras["final_observation"] = timestep.observation # Replace observation with reset observation. - timestep = timestep.replace( # type: ignore + timestep = timestep.replace( observation=reset_timestep.observation, extras=extras - ) + ) # type: ignore return state, timestep