diff --git a/doc/images/mjx/apg_diagram.png b/doc/images/mjx/apg_diagram.png
new file mode 100644
index 0000000000..961798da7c
Binary files /dev/null and b/doc/images/mjx/apg_diagram.png differ
diff --git a/mjx/training_apg.ipynb b/mjx/training_apg.ipynb
new file mode 100644
index 0000000000..d7cfac89d2
--- /dev/null
+++ b/mjx/training_apg.ipynb
@@ -0,0 +1,1426 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Policy learning and Policy Gradients\n",
+ "\n",
+ "This recap of policy learning provides context for how MJX's differentiability can be used to train policies. If the concepts below are unfamiliar, there are many great resources [online](https://spinningup.openai.com/en/latest/spinningup/rl_intro.html)!\n",
+ "\n",
+ "The goal of policy learning is to find a control policy $\\pi$ which outputs actions $a_t \\sim \\pi(\\cdot| x_t, \\theta)$ maximizing the total rewards $\\sum r_t$ over some time period, where $r_t$ is shorthand for a reward function evaluated at the state and action of time t:\n",
+ "\n",
+ "$$r_t = r(x_t, a_t)$$\n",
+ "\n",
+ "$\\theta$ are the parameters of the policy; weights in the common case that the policy is a neural network. **Policy gradient methods** involve estimating the gradient of the rewards with respect to the weights, and using this value in a first-order optimization algorithm such as Gradient Descent or [Adam](https://arxiv.org/abs/1412.6980). How we estimate the policy gradient depends on what state transition model we assume. \n",
+ "\n",
+ "#### Zeroth-Order Policy Gradients (ZoPG)\n",
+ "\n",
+ "Referring to `mjx.step` as the simulation function $f$, we borrow some [terminology](https://arxiv.org/abs/2202.00817) to distinguish between zeroth-order gradients, which depend only on values of $f$, and first-order gradients, which depend on its Jacobian.\n",
+ "\n",
+ "Reinforcement learning (RL) algorithms such as the standard [PPO](https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py) assume the stochastic state transition model $x_{t+1} \\sim P(\\cdot | x_t, a_t)$. This leads to a ZoPG of the form:\n",
+ "\n",
+ "$$\n",
+ "\\nabla_\\theta J(\\pi_\\theta) = \\mathbb{E}_{\\tau \\sim \\pi_\\theta}\\left[ \\sum_{t=0}^{T} \\nabla_\\theta \\log\\pi_\\theta (a_t | x_t) R(\\tau) \\right]\n",
+ "$$\n",
+ "\n",
+ "Where $R(\\tau)$ is some function depending on the rollout $\\tau = \\{x_t, a_t\\}_{t=0}^{T}$. Despite this method's popularity and extensive research into its refinement, a fundamental property is that the gradient has high variance. This allows the optimizer to thoroughly explore the space of policies, leading to the robust and often surprisingly good policies that have been achieved. However, the variance comes at the cost of requiring many samples $(x_t, a_t)$ to converge.\n",
+ "\n",
+ "#### First-Order Policy Gradients (FoPG)\n",
+ "On the other hand, if you assume a deterministic state transition model $x_{t+1} = f(x_t, a_t)$, you end up with the first-order policy gradient. Other common names include Analytical Policy Gradients (APG) and Backpropagation through Time (BPTT). Unlike ZoPG methods, which model the state evolution as a probabilistic black box, the FoPG explicitly contains the Jacobians of the simulation function $f$. For example, let's look at the gradient of the reward $r_t$, in the case that it only depends on state.\n",
+ "$$\n",
+ "\\frac{\\partial r_t}{\\partial \\theta} = \\frac{\\partial r_t}{\\partial x_t}\\frac{\\partial x_t}{\\partial \\theta} \n",
+ "$$\n",
+ "\n",
+ "$$\n",
+ "\\frac{\\partial x_t}{\\partial \\theta} = \\textcolor{Navy}{\\frac{\\partial f(x_{t-1}, a_{t-1})}{\\partial x_{t-1}}}\\frac{\\partial x_{t-1}}{\\partial \\theta} + \\textcolor{Navy}{\\frac{\\partial f(x_{t-1}, a_{t-1})}{\\partial a_{t-1}}} \\frac{\\partial a_{t-1}}{\\partial \\theta}\n",
+ "$$\n",
+ "\n",
+ "The navy-colored terms in the above expression are enabled by MJX's differentiability and are the key difference between FoPG's and ZoPG's. An important consideration is what these Jacobians look like near contact points. To see why certain gradients within the Jacobian can be pathological, imagine a hard sphere falling toward a block of marble. How does its velocity change with respect to distance ($\\frac{\\partial \\dot{z}_t}{\\partial z_t}$), the instant before it touches the block? This is the case of an **uninformative gradient**, due to [hard contact](https://arxiv.org/html/2404.02887v1). Fortunately, the default contact settings in MuJoCo are sufficiently [soft](https://mujoco.readthedocs.io/en/stable/computation/index.html#soft-contact-model) for learning via FoPG's. With soft contacts, the surface applies an increasing force on the sphere as the sphere penetrates it, unlike rigid contacts, which instantly provide enough force for deflection.\n",
+ "\n",
+ "A helpful way to think about FoPG's is via the chain rule and computation graphs, as illustrated below for how $r_2$ influences the policy gradient, again for the case that the reward does not depend on action:\n",
+ "\n",
+ " \n",
+ "\n",
+ "Note that there are three distinct gradient chains in this example. The red pathway considers how the immediately prior action affected the state. The blue path explains the name *Backpropagation through Time*, capturing how actions affect downstream rewards. The least intuitive may be the green chain, which shows how the reward depends on how actions depend on previous actions. Experience shows that blocking *any* of these three pathways via `jax.lax.stop_gradient` can badly hinder policy learning. As the length of the $x_t$ backbone increases, [gradient explosion](https://arxiv.org/abs/2111.05803) becomes a crucial consideration. In practice, this can be resolved by decaying downstream gradients or by periodically truncating the gradient.\n",
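+ "\n",
+ "As a sketch of where the three chains come from, the gradient of $r_2$ can be expanded by hand. The notation here is an assumption made for illustration: the policy is deterministic, $a_t = \\pi(x_t; \\theta)$, the initial state $x_0$ does not depend on $\\theta$, $f_t$ and $\\pi_t$ are shorthand for $f(x_t, a_t)$ and $\\pi(x_t; \\theta)$, and partial derivatives of $f_t$ and $\\pi_t$ are taken with their other inputs held fixed. Applying the chain rule to $x_2 = f(x_1, a_1)$ and $x_1 = f(x_0, a_0)$ gives:\n",
+ "\n",
+ "$$\n",
+ "\\frac{\\partial r_2}{\\partial \\theta} = \\frac{\\partial r_2}{\\partial x_2}\\left[ \\frac{\\partial f_1}{\\partial a_1}\\frac{\\partial \\pi_1}{\\partial \\theta} + \\frac{\\partial f_1}{\\partial x_1}\\frac{\\partial f_0}{\\partial a_0}\\frac{\\partial \\pi_0}{\\partial \\theta} + \\frac{\\partial f_1}{\\partial a_1}\\frac{\\partial \\pi_1}{\\partial x_1}\\frac{\\partial f_0}{\\partial a_0}\\frac{\\partial \\pi_0}{\\partial \\theta} \\right]\n",
+ "$$\n",
+ "\n",
+ "The first term captures the immediately prior action $a_1$, the second carries the effect of $a_0$ along the state backbone, and the third captures how $a_1$ depends on $a_0$ through the state $x_1$ that the policy observes. These appear to correspond, in that order, to the three chains described above.\n",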
+ "\n",
+ "**The Sharp Bits of FoPG's**\n",
+ "\n",
+ "While FoPG's have been shown to be very sample efficient, especially as the [dimension of the state space increases](https://arxiv.org/abs/2204.07137), one fundamental shortcoming is that, due to the lower gradient variance, FoPG's also have less exploration power than ZoPG's and benefit from the practitioner being more explicit in the problem formulation.\n",
+ "\n",
+ "Additionally, discontinuous reward formulations are ubiquitous in RL, for instance, a large penalty when the robot falls. It can be significantly more [challenging](https://arxiv.org/abs/2403.14864) to design robust policies with FoPG's, since they cannot backprop through such penalties.\n",
+ "\n",
+ "Last, despite the sample efficiency, FoPG methods can still struggle with wall-clock time. Because the gradients have low variance, they do not benefit significantly from massive parallelization of data collection, unlike [RL](https://arxiv.org/abs/2109.11978). Additionally, the policy gradient is typically calculated via autodifferentiation. This can be 3-5x slower than unrolling the simulation forward, and memory intensive, with memory requirements scaling with $O(m \\cdot (m+n) \\cdot T)$, where $m$ and $n$ are the state and control dimensions, $m \\cdot (m+n)$ is the Jacobian dimension, and $T$ is the number of steps propagated through.\n",
+ "\n",
+ "Note that with certain models, using autodifferentiation through `mjx.step` currently causes [NaN gradients](https://github.com/google-deepmind/mujoco/issues/1517). For now, we address this issue by using double-precision floats, at the cost of doubling the memory requirements and training time."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "\n",
+ "**Coming Up**\n",
+ "\n",
+ "In this tutorial, we demonstrate two ways to use FoPG's, using Brax's simple APG [algorithm](https://github.com/google/brax/tree/main/brax/training/agents/apg). This algorithm essentially uses FoPG's to perform live gradient descent on the policy, unrolling it for a short window, using the data to do a policy update, then continuing where it left off."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup: Imports and installations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"0.8\" # 0.9 causes too much lag. \n",
+ "from datetime import datetime\n",
+ "import functools\n",
+ "\n",
+ "# Math\n",
+ "import jax.numpy as jp\n",
+ "import numpy as np\n",
+ "import jax\n",
+ "from jax import config # Analytical gradients work much better with double precision.\n",
+ "config.update(\"jax_debug_nans\", True)\n",
+ "config.update(\"jax_enable_x64\", True)\n",
+ "config.update('jax_default_matmul_precision', jax.lax.Precision.HIGH)\n",
+ "from brax import math\n",
+ "\n",
+ "# Sim\n",
+ "import mujoco\n",
+ "import mujoco.mjx as mjx\n",
+ "\n",
+ "# Brax\n",
+ "from brax import envs\n",
+ "from brax.base import Motion, Transform\n",
+ "from brax.io import mjcf\n",
+ "from brax.envs.base import PipelineEnv, State\n",
+ "from brax.mjx.pipeline import _reformat_contact\n",
+ "from brax.training.acme import running_statistics\n",
+ "from brax.io import model\n",
+ "\n",
+ "# Algorithms\n",
+ "from brax.training.agents.apg import train as apg\n",
+ "from brax.training.agents.apg import networks as apg_networks\n",
+ "from brax.training.agents.ppo import train as ppo\n",
+ "\n",
+ "# Supporting\n",
+ "from etils import epath\n",
+ "import mediapy as media\n",
+ "import matplotlib.pyplot as plt\n",
+ "from ml_collections import config_dict\n",
+ "from typing import Any, Dict\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Quadruped Env"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!git clone https://github.com/google-deepmind/mujoco_menagerie.git"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "xml_path = epath.Path('mujoco_menagerie/anybotics_anymal_c/scene_mjx.xml').as_posix()\n",
+ "\n",
+ "mj_model = mujoco.MjModel.from_xml_path(xml_path)\n",
+ "\n",
+ "if 'renderer' not in dir():\n",
+ "  renderer = mujoco.Renderer(mj_model)\n",
+ "\n",
+ "init_q = mj_model.keyframe('standing').qpos\n",
+ "\n",
+ "mj_data = mujoco.MjData(mj_model)\n",
+ "mj_data.qpos = init_q\n",
+ "mujoco.mj_forward(mj_model, mj_data)\n",
+ "\n",
+ "renderer.update_scene(mj_data)\n",
+ "media.show_image(renderer.render())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Rendering Rollouts\n",
+ "def render_rollout(reset_fn, step_fn,\n",
+ "                   inference_fn, env,\n",
+ "                   n_steps=200, camera=None,\n",
+ "                   seed=0):\n",
+ "  rng = jax.random.key(seed)\n",
+ "  render_every = 3\n",
+ "  state = reset_fn(rng)\n",
+ "  rollout = [state.pipeline_state]\n",
+ "\n",
+ "  for i in range(n_steps):\n",
+ "    act_rng, rng = jax.random.split(rng)\n",
+ "    ctrl, _ = inference_fn(state.obs, act_rng)\n",
+ "    state = step_fn(state, ctrl)\n",
+ "    if i % render_every == 0:\n",
+ "      rollout.append(state.pipeline_state)\n",
+ "\n",
+ "  media.show_video(env.render(rollout, camera=camera),\n",
+ "                   fps=1.0 / (env.dt*render_every),\n",
+ "                   codec='gif')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Study 1: Imitating Kinematics\n",
+ "\n",
+ "FoPG's have been demonstrated to work well in [imitation learning](https://openreview.net/forum?id=06mk-epSwZ), especially when the agent state is reset to track the reference state when it gets too far. In this section, we learn to trot in place. Due to the limited gain on the PD controllers, trotting quickly is far from trivial!\n",
+ "\n",
+ "The RL environment has three rewards:\n",
+ "- min_reference_tracking penalizes error from the reference motion, in minimal coordinates. This makes the policy's output more precise.\n",
+ "- reference_tracking penalizes error in maximal coordinates and increases training stability.\n",
+ "- feet_height nudges the balance of which body positions and velocities to track, by placing additional incentive on the *position* of the feet."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Designing Reference Kinematics"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def cos_wave(t, step_period, scale):\n",
+ "  _cos_wave = -jp.cos(((2*jp.pi)/step_period)*t)\n",
+ "  return _cos_wave * (scale/2) + (scale/2)\n",
+ "\n",
+ "def dcos_wave(t, step_period, scale):\n",
+ "  \"\"\"\n",
+ "  Derivative of the cos wave, for reference velocity\n",
+ "  \"\"\"\n",
+ "  return ((scale*jp.pi) / step_period) * jp.sin(((2*jp.pi)/step_period)*t)\n",
+ "\n",
+ "def make_kinematic_ref(sinusoid, step_k, scale=0.3, dt=1/50):\n",
+ "  \"\"\"\n",
+ "  Makes trotting kinematics for the 12 leg joints.\n",
+ "  step_k is the number of timesteps it takes to raise and lower a given foot.\n",
+ "  A gait cycle is 2 * step_k * dt seconds long.\n",
+ "  \"\"\"\n",
+ "\n",
+ "  _steps = jp.arange(step_k)\n",
+ "  step_period = step_k * dt\n",
+ "  t = _steps * dt\n",
+ "\n",
+ "  wave = sinusoid(t, step_period, scale)\n",
+ "  # Commands for one step of an active front leg\n",
+ "  fleg_cmd_block = jp.concatenate(\n",
+ "      [jp.zeros((step_k, 1)),\n",
+ "       wave.reshape(step_k, 1),\n",
+ "       -2*wave.reshape(step_k, 1)],\n",
+ "      axis=1\n",
+ "  )\n",
+ "  # Our standing config reverses front and hind legs\n",
+ "  h_leg_cmd_bloc = -1 * fleg_cmd_block\n",
+ "\n",
+ "  block1 = jp.concatenate([\n",
+ "      jp.zeros((step_k, 3)),\n",
+ "      fleg_cmd_block,\n",
+ "      h_leg_cmd_bloc,\n",
+ "      jp.zeros((step_k, 3))],\n",
+ "      axis=1\n",
+ "  )\n",
+ "\n",
+ "  block2 = jp.concatenate([\n",
+ "      fleg_cmd_block,\n",
+ "      jp.zeros((step_k, 3)),\n",
+ "      jp.zeros((step_k, 3)),\n",
+ "      h_leg_cmd_bloc],\n",
+ "      axis=1\n",
+ "  )\n",
+ "  # In one step cycle, both pairs of active legs have inactive and active phases\n",
+ "  step_cycle = jp.concatenate([block1, block2], axis=0)\n",
+ "  return step_cycle\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "poses = make_kinematic_ref(cos_wave, step_k=25)\n",
+ "\n",
+ "frames = []\n",
+ "init_q = mj_model.keyframe('standing').qpos\n",
+ "mj_data.qpos = init_q\n",
+ "default_ap = init_q[7:]\n",
+ "\n",
+ "for i in range(len(poses)):\n",
+ "  mj_data.qpos[7:] = poses[i] + default_ap\n",
+ "  mujoco.mj_forward(mj_model, mj_data)\n",
+ "  renderer.update_scene(mj_data)\n",
+ "  frames.append(renderer.render())\n",
+ "\n",
+ "media.show_video(frames, fps=50, codec='gif')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### The RL Environment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_config():\n",
+ "  def get_default_rewards_config():\n",
+ "    default_config = config_dict.ConfigDict(\n",
+ "        dict(\n",
+ "            scales=config_dict.ConfigDict(\n",
+ "                dict(\n",
+ "                    min_reference_tracking = -2.5 * 3e-3, # to equalize the magnitude\n",
+ "                    reference_tracking = -1.0,\n",
+ "                    feet_height = -1.0\n",
+ "                )\n",
+ "            )\n",
+ "        )\n",
+ "    )\n",
+ "    return default_config\n",
+ "\n",
+ "  default_config = config_dict.ConfigDict(\n",
+ "      dict(rewards=get_default_rewards_config(),))\n",
+ "\n",
+ "  return default_config\n",
+ "\n",
+ "# Math functions from (https://github.com/jiawei-ren/diffmimic)\n",
+ "def quaternion_to_matrix(quaternions):\n",
+ "  r, i, j, k = quaternions[..., 0], quaternions[..., 1], quaternions[..., 2], quaternions[..., 3]\n",
+ "  two_s = 2.0 / (quaternions * quaternions).sum(-1)\n",
+ "\n",
+ "  o = jp.stack(\n",
+ "      (\n",
+ "          1 - two_s * (j * j + k * k),\n",
+ "          two_s * (i * j - k * r),\n",
+ "          two_s * (i * k + j * r),\n",
+ "          two_s * (i * j + k * r),\n",
+ "          1 - two_s * (i * i + k * k),\n",
+ "          two_s * (j * k - i * r),\n",
+ "          two_s * (i * k - j * r),\n",
+ "          two_s * (j * k + i * r),\n",
+ "          1 - two_s * (i * i + j * j),\n",
+ "      ),\n",
+ "      -1,\n",
+ "  )\n",
+ "  return o.reshape(quaternions.shape[:-1] + (3, 3))\n",
+ "\n",
+ "def matrix_to_rotation_6d(matrix):\n",
+ "  batch_dim = matrix.shape[:-2]\n",
+ "  return matrix[..., :2, :].reshape(batch_dim + (6,))\n",
+ "\n",
+ "def quaternion_to_rotation_6d(quaternion):\n",
+ "  return matrix_to_rotation_6d(quaternion_to_matrix(quaternion))\n",
+ "\n",
+ "class TrotAnymal(PipelineEnv):\n",
+ "\n",
+ "  def __init__(\n",
+ "      self,\n",
+ "      termination_height: float=0.25,\n",
+ "      **kwargs,\n",
+ "  ):\n",
+ "    step_k = kwargs.pop('step_k', 25)\n",
+ "\n",
+ "    physics_steps_per_control_step = 10\n",
+ "    kwargs['n_frames'] = kwargs.get(\n",
+ "        'n_frames', physics_steps_per_control_step)\n",
+ "\n",
+ "    mj_model = mujoco.MjModel.from_xml_path(xml_path)\n",
+ "    kp = 230\n",
+ "    mj_model.actuator_gainprm[:, 0] = kp\n",
+ "    mj_model.actuator_biasprm[:, 1] = -kp\n",
+ "\n",
+ "    sys = mjcf.load_model(mj_model)\n",
+ "\n",
+ "    super().__init__(sys=sys, **kwargs)\n",
+ "\n",
+ "    self.termination_height = termination_height\n",
+ "\n",
+ "    self._init_q = mj_model.keyframe('standing').qpos\n",
+ "\n",
+ "    self.err_threshold = 0.4 # diffmimic; value from paper.\n",
+ "\n",
+ "    self._default_ap_pose = mj_model.keyframe('standing').qpos[7:]\n",
+ "    self.reward_config = get_config()\n",
+ "\n",
+ "    self.action_loc = self._default_ap_pose\n",
+ "    self.action_scale = jp.array([0.2, 0.8, 0.8] * 4)\n",
+ "\n",
+ "    self.feet_inds = jp.array([21,28,35,42]) # LF, RF, LH, RH\n",
+ "\n",
+ "    #### Imitation reference\n",
+ "    kinematic_ref_qpos = make_kinematic_ref(\n",
+ "        cos_wave, step_k, scale=0.3, dt=self.dt)\n",
+ "    kinematic_ref_qvel = make_kinematic_ref(\n",
+ "        dcos_wave, step_k, scale=0.3, dt=self.dt)\n",
+ "\n",
+ "    self.l_cycle = jp.array(kinematic_ref_qpos.shape[0])\n",
+ "\n",
+ "    # Expand to entire state space.\n",
+ "\n",
+ "    kinematic_ref_qpos += self._default_ap_pose\n",
+ "    ref_qs = np.tile(self._init_q.reshape(1, 19), (self.l_cycle, 1))\n",
+ "    ref_qs[:, 7:] = kinematic_ref_qpos\n",
+ "    self.kinematic_ref_qpos = jp.array(ref_qs)\n",
+ "\n",
+ "    ref_qvels = np.zeros((self.l_cycle, 18))\n",
+ "    ref_qvels[:, 6:] = kinematic_ref_qvel\n",
+ "    self.kinematic_ref_qvel = jp.array(ref_qvels)\n",
+ "\n",
+ "    # Can decrease jit time and training wall-clock time significantly.\n",
+ "    self.pipeline_step = jax.checkpoint(self.pipeline_step,\n",
+ "        policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)\n",
+ "\n",
+ "  def reset(self, rng: jax.Array) -> State:\n",
+ "    # Deterministic init\n",
+ "\n",
+ "    qpos = jp.array(self._init_q)\n",
+ "    qvel = jp.zeros(18)\n",
+ "\n",
+ "    data = self.pipeline_init(qpos, qvel)\n",
+ "\n",
+ "    # Position onto ground\n",
+ "    pen = jp.min(data.contact.dist)\n",
+ "    qpos = qpos.at[2].set(qpos[2] - pen)\n",
+ "    data = self.pipeline_init(qpos, qvel)\n",
+ "\n",
+ "    state_info = {\n",
+ "        'rng': rng,\n",
+ "        'steps': 0.0,\n",
+ "        'reward_tuple': {\n",
+ "            'reference_tracking': 0.0,\n",
+ "            'min_reference_tracking': 0.0,\n",
+ "            'feet_height': 0.0\n",
+ "        },\n",
+ "        'last_action': jp.zeros(12), # from MJX tutorial.\n",
+ "        'kinematic_ref': jp.zeros(19),\n",
+ "    }\n",
+ "\n",
+ "    x, xd = data.x, data.xd\n",
+ "    obs = self._get_obs(data.qpos, x, xd, state_info)\n",
+ "    reward, done = jp.zeros(2)\n",
+ "    metrics = {}\n",
+ "    for k in state_info['reward_tuple']:\n",
+ "      metrics[k] = state_info['reward_tuple'][k]\n",
+ "    state = State(data, obs, reward, done, metrics, state_info)\n",
+ "    return jax.lax.stop_gradient(state)\n",
+ "\n",
+ "  def step(self, state: State, action: jax.Array) -> State:\n",
+ "    action = jp.clip(action, -1, 1) # Raw action\n",
+ "\n",
+ "    action = self.action_loc + (action * self.action_scale)\n",
+ "\n",
+ "    data = self.pipeline_step(state.pipeline_state, action)\n",
+ "\n",
+ "    ref_qpos = self.kinematic_ref_qpos[jp.array(state.info['steps']%self.l_cycle, int)]\n",
+ "    ref_qvel = self.kinematic_ref_qvel[jp.array(state.info['steps']%self.l_cycle, int)]\n",
+ "\n",
+ "    # Calculate maximal coordinates\n",
+ "    ref_data = data.replace(qpos=ref_qpos, qvel=ref_qvel)\n",
+ "    ref_data = mjx.forward(self.sys, ref_data)\n",
+ "    ref_x, ref_xd = ref_data.x, ref_data.xd\n",
+ "\n",
+ "    state.info['kinematic_ref'] = ref_qpos\n",
+ "\n",
+ "    # observation data\n",
+ "    x, xd = data.x, data.xd\n",
+ "    obs = self._get_obs(data.qpos, x, xd, state.info)\n",
+ "\n",
+ "    # Terminate if flipped over or fallen down.\n",
+ "    done = 0.0\n",
+ "    done = jp.where(x.pos[0, 2] < self.termination_height, 1.0, done)\n",
+ "    up = jp.array([0.0, 0.0, 1.0])\n",
+ "    done = jp.where(jp.dot(math.rotate(up, x.rot[0]), up) < 0, 1.0, done)\n",
+ "\n",
+ "    # reward\n",
+ "    reward_tuple = {\n",
+ "        'reference_tracking': (\n",
+ "            self._reward_reference_tracking(x, xd, ref_x, ref_xd)\n",
+ "            * self.reward_config.rewards.scales.reference_tracking\n",
+ "        ),\n",
+ "        'min_reference_tracking': (\n",
+ "            self._reward_min_reference_tracking(ref_qpos, ref_qvel, state)\n",
+ "            * self.reward_config.rewards.scales.min_reference_tracking\n",
+ "        ),\n",
+ "        'feet_height': (\n",
+ "            self._reward_feet_height(data.geom_xpos[self.feet_inds][:, 2],\n",
+ "                                     ref_data.geom_xpos[self.feet_inds][:, 2])\n",
+ "            * self.reward_config.rewards.scales.feet_height\n",
+ "        )\n",
+ "    }\n",
+ "\n",
+ "    reward = sum(reward_tuple.values())\n",
+ "\n",
+ "    # state management\n",
+ "    state.info['reward_tuple'] = reward_tuple\n",
+ "    state.info['last_action'] = action # used for observation.\n",
+ "\n",
+ "    for k in state.info['reward_tuple'].keys():\n",
+ "      state.metrics[k] = state.info['reward_tuple'][k]\n",
+ "\n",
+ "    state = state.replace(\n",
+ "        pipeline_state=data, obs=obs, reward=reward,\n",
+ "        done=done)\n",
+ "\n",
+ "    #### Reset state to reference if it gets too far\n",
+ "    error = (((x.pos - ref_x.pos) ** 2).sum(-1)**0.5).mean()\n",
+ "    to_reference = jp.where(error > self.err_threshold, 1.0, 0.0)\n",
+ "\n",
+ "    to_reference = jp.array(to_reference, dtype=int) # keeps output types same as input.\n",
+ "    ref_data = self.mjx_to_brax(ref_data)\n",
+ "\n",
+ "    data = jax.tree_util.tree_map(lambda x, y:\n",
+ "        jp.array((1-to_reference)*x + to_reference*y, x.dtype), data, ref_data)\n",
+ "\n",
+ "    x, xd = data.x, data.xd # Data may have changed.\n",
+ "    obs = self._get_obs(data.qpos, x, xd, state.info)\n",
+ "\n",
+ "    return state.replace(pipeline_state=data, obs=obs)\n",
+ "\n",
+ "  def _get_obs(self, qpos: jax.Array, x: Transform, xd: Motion,\n",
+ "               state_info: Dict[str, Any]) -> jax.Array:\n",
+ "\n",
+ "    inv_base_orientation = math.quat_inv(x.rot[0])\n",
+ "    local_rpyrate = math.rotate(xd.ang[0], inv_base_orientation)\n",
+ "\n",
+ "    obs_list = []\n",
+ "    # yaw rate\n",
+ "    obs_list.append(jp.array([local_rpyrate[2]]) * 0.25)\n",
+ "    # projected gravity\n",
+ "    obs_list.append(\n",
+ "        math.rotate(jp.array([0.0, 0.0, -1.0]), inv_base_orientation))\n",
+ "    # motor angles\n",
+ "    angles = qpos[7:19]\n",
+ "    obs_list.append(angles - self._default_ap_pose)\n",
+ "    # last action\n",
+ "    obs_list.append(state_info['last_action'])\n",
+ "    # kinematic reference\n",
+ "    kin_ref = self.kinematic_ref_qpos[jp.array(state_info['steps']%self.l_cycle, int)]\n",
+ "    obs_list.append(kin_ref[7:]) # First 7 indices are fixed\n",
+ "\n",
+ "    obs = jp.clip(jp.concatenate(obs_list), -100.0, 100.0)\n",
+ "\n",
+ "    return obs\n",
+ "\n",
+ "  def mjx_to_brax(self, data):\n",
+ "    \"\"\"\n",
+ "    Apply the brax wrapper on the core MJX data structure.\n",
+ "    \"\"\"\n",
+ "    q, qd = data.qpos, data.qvel\n",
+ "    x = Transform(pos=data.xpos[1:], rot=data.xquat[1:])\n",
+ "    cvel = Motion(vel=data.cvel[1:, 3:], ang=data.cvel[1:, :3])\n",
+ "    offset = data.xpos[1:, :] - data.subtree_com[self.sys.body_rootid[1:]]\n",
+ "    offset = Transform.create(pos=offset)\n",
+ "    xd = offset.vmap().do(cvel)\n",
+ "    data = _reformat_contact(self.sys, data)\n",
+ "    return data.replace(q=q, qd=qd, x=x, xd=xd)\n",
+ "\n",
+ "\n",
+ "  # ------------ reward functions----------------\n",
+ "  def _reward_reference_tracking(self, x, xd, ref_x, ref_xd):\n",
+ "    \"\"\"\n",
+ "    Rewards based on inertial-frame body positions.\n",
+ "    Notably, we use a high-dimension representation of orientation.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    f = lambda x, y: ((x - y) ** 2).sum(-1).mean()\n",
+ "\n",
+ "    _mse_pos = f(x.pos, ref_x.pos)\n",
+ "    _mse_rot = f(quaternion_to_rotation_6d(x.rot),\n",
+ "                 quaternion_to_rotation_6d(ref_x.rot))\n",
+ "    _mse_vel = f(xd.vel, ref_xd.vel)\n",
+ "    _mse_ang = f(xd.ang, ref_xd.ang)\n",
+ "\n",
+ "    # Tuned to be about the same size.\n",
+ "    return _mse_pos \\\n",
+ "        + 0.1 * _mse_rot \\\n",
+ "        + 0.01 * _mse_vel \\\n",
+ "        + 0.001 * _mse_ang\n",
+ "\n",
+ "  def _reward_min_reference_tracking(self, ref_qpos, ref_qvel, state):\n",
+ "    \"\"\"\n",
+ "    Using minimal coordinates. Improves accuracy of joint angle tracking.\n",
+ "    \"\"\"\n",
+ "    pos = jp.concatenate([\n",
+ "        state.pipeline_state.qpos[:3],\n",
+ "        state.pipeline_state.qpos[7:]])\n",
+ "    pos_targ = jp.concatenate([\n",
+ "        ref_qpos[:3],\n",
+ "        ref_qpos[7:]])\n",
+ "    pos_err = jp.linalg.norm(pos_targ - pos)\n",
+ "    vel_err = jp.linalg.norm(state.pipeline_state.qvel - ref_qvel)\n",
+ "\n",
+ "    return pos_err + vel_err\n",
+ "\n",
+ "  def _reward_feet_height(self, feet_pos, feet_pos_ref):\n",
+ "    return jp.sum(jp.abs(feet_pos - feet_pos_ref)) # try to drive it to 0 using the l1 norm.\n",
+ "\n",
+ "envs.register_environment('trotting_anymal', TrotAnymal)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Imitation Learning via FoPG's\n",
+ "Takes about 15 minutes on an NVIDIA 3060 Ti GPU"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "make_networks_factory = functools.partial(\n",
+ " apg_networks.make_apg_networks,\n",
+ " hidden_layer_sizes=(256, 128)\n",
+ ")\n",
+ "\n",
+ "epochs = 499\n",
+ "\n",
+ "train_fn = functools.partial(apg.train,\n",
+ " episode_length=240,\n",
+ " policy_updates=epochs,\n",
+ " horizon_length=32,\n",
+ " num_envs=64,\n",
+ " learning_rate=1e-4,\n",
+ " num_eval_envs=64,\n",
+ " num_evals=10 + 1,\n",
+ " use_float64=True,\n",
+ " normalize_observations=True,\n",
+ " network_factory=make_networks_factory)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "x_data = []\n",
+ "y_data = []\n",
+ "ydataerr = []\n",
+ "times = [datetime.now()]\n",
+ "\n",
+ "def progress(it, metrics):\n",
+ "  times.append(datetime.now())\n",
+ "  x_data.append(it)\n",
+ "  y_data.append(metrics['eval/episode_reward'])\n",
+ "  ydataerr.append(metrics['eval/episode_reward_std'])\n",
+ "\n",
+ "# Each foot contacts the ground twice/sec.\n",
+ "env = envs.get_environment(\"trotting_anymal\", step_k = 13)\n",
+ "eval_env = envs.get_environment(\"trotting_anymal\", step_k = 13)\n",
+ "\n",
+ "make_inference_fn, params, _= train_fn(environment=env,\n",
+ " progress_fn=progress,\n",
+ " eval_env=eval_env)\n",
+ "\n",
+ "plt.errorbar(x_data, y_data, yerr=ydataerr)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "demo_env = envs.training.EpisodeWrapper(env, \n",
+ " episode_length=1000, \n",
+ " action_repeat=1)\n",
+ "\n",
+ "render_rollout(\n",
+ " jax.jit(demo_env.reset),\n",
+ " jax.jit(demo_env.step),\n",
+ " jax.jit(make_inference_fn(params)),\n",
+ " demo_env,\n",
+ " n_steps=200,\n",
+ " seed=1\n",
+ ")\n",
+ "\n",
+ "model_path = '/tmp/trotting_2hz_policy'\n",
+ "model.save_params(model_path, params)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**A note on sample efficiency**\n",
+ "\n",
+ "Above, we train using epochs * horizon_length * num_envs = 499 * 32 * 64 ≈ 1.02e6 total simulator steps. Let's see what we get from PPO, using roughly ten times more samples:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Text(0, 0.5, 'reward per episode')"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "train_fn = functools.partial(\n",
+ " ppo.train, num_timesteps=10_000_000, num_evals=10, reward_scaling=0.1,\n",
+ " episode_length=1000, normalize_observations=True, action_repeat=1,\n",
+ " unroll_length=10, num_minibatches=32, num_updates_per_batch=8,\n",
+ " discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=1024,\n",
+ " batch_size=1024, seed=0)\n",
+ "\n",
+ "x_data = []\n",
+ "y_data = []\n",
+ "ydataerr = []\n",
+ "env = envs.get_environment(\"trotting_anymal\", step_k = 13)\n",
+ "\n",
+ "def progress(num_steps, metrics):\n",
+ "  x_data.append(num_steps)\n",
+ "  y_data.append(metrics['eval/episode_reward'])\n",
+ "  ydataerr.append(metrics['eval/episode_reward_std'])\n",
+ "\n",
+ "make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)\n",
+ "\n",
+ "plt.errorbar(x_data, y_data, yerr=ydataerr)\n",
+ "plt.xlabel('# environment steps')\n",
+ "plt.ylabel('reward per episode')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We see that PPO takes around 9e6 simulator steps to catch up to APG."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Study 2: Quadruped Locomotion\n",
+ "\n",
+ "As we saw in the imitation learning example, FoPG methods benefit from detailed reward signals. To teach locomotion, we reward the feet based on the Raibert Heuristic. Similarly to [prior work](https://arxiv.org/abs/2403.14864), we use a gait schedule to incentivize opposite pairs of legs to move in sync at a fixed frequency. At the beginning of a new scheduled step, we calculate the target position for the feet at the end of the step. \n",
+ "\n",
+ "For each foot, we calculate:\n",
+ "\n",
+ "$$\n",
+ "p^* = h_0 + \\frac{\\Delta T}{2} v_0\n",
+ "$$\n",
+ "\n",
+ "Where $p^*$ is the x, y component of the foot's target position, $h_0$ is the x, y component of the corresponding hip at lift-off, $\\Delta T$ is the scheduled step duration, and $v_0$ is the base velocity at lift-off.\n",
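+ "\n",
+ "As a worked example with illustrative numbers (assumptions, not values read from the environment): suppose the scheduled step lasts $\\Delta T = 0.26$ s, which is what 13 control steps of 0.02 s would give, and the base moves straight ahead at 0.75 m/s, so $v_0 = (0.75, 0)$ m/s. The heuristic then places the target roughly 10 cm ahead of where the hip was at lift-off:\n",
+ "\n",
+ "$$\n",
+ "p^* - h_0 = \\frac{\\Delta T}{2} v_0 = \\frac{0.26}{2} \\cdot (0.75, 0) = (0.0975, 0) \\text{ m}\n",
+ "$$\n",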
+ "\n",
+ "Due to their limited exploration power, FoPG methods benefit greatly from having a good \"initial guess\" of the policy, a familiar idea in Model Predictive Control and Trajectory Optimization. We formulate the problem as [residual learning](https://arxiv.org/abs/1512.03385). Let $\\phi$ be the parameters of a baseline policy that we already have, and let $f$ and $g$ be the neural networks for the learned and baseline policy. We freeze $\\phi$ and learn parameters $\\theta$, for the policy:\n",
+ "\n",
+ "$$\n",
+ "a_t = f(g(x_t; \\phi), x_t; \\theta) + g(x_t; \\phi)\n",
+ "$$\n",
+ "\n",
+ "We use the parameters of the in-place trotting policy from the last section as $\\phi$, and $x_t$ denotes the state at time $t$. In this example we track a 0.75 m/s velocity target, but since locomotion is more stable at faster trots, you can experiment with $\\phi$ for faster velocity targets!\n",
+ "\n",
+ "While it might seem more natural to \"hotstart\" the learning by simply initializing the parameters $\\theta$ as $\\phi$, with the policy $a_t = f(x_t; \\theta)$, we find that the residual method trains more stably in practice."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### The RL Environment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "def axis_angle_to_quaternion(v: jp.ndarray, theta: jp.float_):\n",
+    "  \"\"\"Returns the (w, x, y, z) quaternion for a rotation of theta radians about the unit axis v.\"\"\"\n",
+    "  return jp.concatenate([jp.cos(0.5*theta).reshape(1), jp.sin(0.5*theta)*v.reshape(3)])\n",
+ "\n",
+ "def get_config():\n",
+ " \"\"\"Returns reward config for anymal quadruped environment.\"\"\"\n",
+ "\n",
+ " def get_default_rewards_config():\n",
+ " default_config = config_dict.ConfigDict(\n",
+ " dict(\n",
+ " scales=config_dict.ConfigDict(\n",
+ " dict(\n",
+ " tracking_lin_vel = 1.0,\n",
+ " orientation = -1.0, # non-flat base\n",
+ " height = 0.5,\n",
+ " lin_vel_z=-1.0, # prevents the suicide policy\n",
+ " torque = -0.01,\n",
+ " feet_pos = -1, # Bad action hard-coding. \n",
+ " feet_height = -1, # prevents it from just standing still\n",
+ " joint_velocity = -0.001\n",
+ " )\n",
+ " ),\n",
+ " )\n",
+ " )\n",
+ " return default_config\n",
+ "\n",
+ " default_config = config_dict.ConfigDict(\n",
+ " dict(rewards=get_default_rewards_config(),))\n",
+ "\n",
+ " return default_config\n",
+ "\n",
+ "class FwdTrotAnymal(PipelineEnv):\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " termination_height: float=0.25,\n",
+ " **kwargs,\n",
+ " ):\n",
+ " \n",
+ " self.target_vel = kwargs.pop('target_vel', 0.75)\n",
+ " step_k = kwargs.pop('step_k', 25)\n",
+ " self.baseline_inference_fn = kwargs.pop(\"baseline_inference_fn\")\n",
+ " physics_steps_per_control_step = 10\n",
+ " kwargs['n_frames'] = kwargs.get(\n",
+ " 'n_frames', physics_steps_per_control_step)\n",
+ " self.termination_height = termination_height\n",
+ "\n",
+ " mj_model = mujoco.MjModel.from_xml_path(xml_path)\n",
+ " kp = 230\n",
+ " mj_model.actuator_gainprm[:, 0] = kp\n",
+ " mj_model.actuator_biasprm[:, 1] = -kp\n",
+ " self._init_q = mj_model.keyframe('standing').qpos\n",
+ " self._default_ap_pose = mj_model.keyframe('standing').qpos[7:]\n",
+ " self.reward_config = get_config()\n",
+ "\n",
+ " self.action_loc = self._default_ap_pose\n",
+ " self.action_scale = jp.array([0.2, 0.8, 0.8] * 4)\n",
+ " \n",
+ " self.target_h = self._init_q[2]\n",
+ "\n",
+ " sys = mjcf.load_model(mj_model)\n",
+ " super().__init__(sys=sys, **kwargs)\n",
+ " \n",
+ " \"\"\"\n",
+ " Kinematic references are used for gait scheduling.\n",
+ " \"\"\"\n",
+ "\n",
+ " kinematic_ref_qpos = make_kinematic_ref(\n",
+ " cos_wave, step_k, scale=0.3, dt=self.dt)\n",
+ " self.l_cycle = jp.array(kinematic_ref_qpos.shape[0])\n",
+ " self.kinematic_ref_qpos = jp.array(kinematic_ref_qpos + self._default_ap_pose)\n",
+ "\n",
+ " \"\"\"\n",
+ " Foot tracking\n",
+ " \"\"\"\n",
+ " gait_k = step_k * 2\n",
+ " self.gait_period = gait_k * self.dt\n",
+ "\n",
+ " self.step_k = step_k\n",
+ " self.feet_inds = jp.array([21,28,35,42]) # LF, RF, LH, RH\n",
+ " self.hip_inds = self.feet_inds - 6\n",
+ "\n",
+ " self.pipeline_step = jax.checkpoint(self.pipeline_step,\n",
+ " policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)\n",
+ " \n",
+ " def reset(self, rng: jax.Array) -> State:\n",
+ " rng, key_xyz, key_ang, key_ax, key_q, key_qd = jax.random.split(rng, 6)\n",
+ "\n",
+ " qpos = jp.array(self._init_q)\n",
+ " qvel = jp.zeros(18)\n",
+ " \n",
+ " #### Add Randomness ####\n",
+ " \n",
+ " r_xyz = 0.2 * (jax.random.uniform(key_xyz, (3,))-0.5)\n",
+ " r_angle = (jp.pi/12) * (jax.random.uniform(key_ang, (1,)) - 0.5) # 15 deg range\n",
+ " r_axis = (jax.random.uniform(key_ax, (3,)) - 0.5)\n",
+ " r_axis = r_axis / jp.linalg.norm(r_axis)\n",
+ " r_quat = axis_angle_to_quaternion(r_axis, r_angle)\n",
+ "\n",
+ " r_joint_q = 0.2 * (jax.random.uniform(key_q, (12,)) - 0.5)\n",
+ " r_joint_qd = 0.1 * (jax.random.uniform(key_qd, (12,)) - 0.5)\n",
+ " \n",
+ " qpos = qpos.at[0:3].set(qpos[0:3] + r_xyz)\n",
+ " qpos = qpos.at[3:7].set(r_quat)\n",
+ " qpos = qpos.at[7:19].set(qpos[7:19] + r_joint_q)\n",
+ " qvel = qvel.at[6:18].set(qvel[6:18] + r_joint_qd)\n",
+ " \n",
+ " data = self.pipeline_init(qpos, qvel)\n",
+ "\n",
+    "    # Shift the base vertically so the closest contact sits exactly at the surface:\n",
+    "    # neither sunken into the ground nor floating above it.\n",
+ " pen = jp.min(data.contact.dist)\n",
+ " qpos = qpos.at[2].set(qpos[2] - pen)\n",
+ " data = self.pipeline_init(qpos, qvel)\n",
+ "\n",
+ " state_info = {\n",
+ " 'rng': rng,\n",
+ " 'steps': 0.0,\n",
+ " 'reward_tuple': {\n",
+ " 'tracking_lin_vel': 0.0,\n",
+ " 'orientation': 0.0,\n",
+ " 'height': 0.0,\n",
+ " 'lin_vel_z': 0.0,\n",
+ " 'torque': 0.0,\n",
+ " 'joint_velocity': 0.0,\n",
+ " 'feet_pos': 0.0,\n",
+ " 'feet_height': 0.0\n",
+ " },\n",
+ " 'last_action': jp.zeros(12), # from MJX tutorial.\n",
+ " 'baseline_action': jp.zeros(12),\n",
+ " 'xy0': jp.zeros((4, 2)),\n",
+ " 'k0': 0.0,\n",
+ " 'xy*': jp.zeros((4, 2))\n",
+ " }\n",
+ "\n",
+ " x, xd = data.x, data.xd\n",
+ " _obs = self._get_obs(data.qpos, x, xd, state_info) # inner obs; to trotter\n",
+ " \n",
+ " action_key, key = jax.random.split(state_info['rng'])\n",
+ " state_info['rng'] = key\n",
+ " next_action, _ = self.baseline_inference_fn(_obs, action_key)\n",
+ "\n",
+ " obs = jp.concatenate([_obs, next_action])\n",
+ "\n",
+ " reward, done = jp.zeros(2)\n",
+ " metrics = {}\n",
+ " for k in state_info['reward_tuple']:\n",
+ " metrics[k] = state_info['reward_tuple'][k]\n",
+ " state = State(data, obs, reward, done, metrics, state_info)\n",
+ " return jax.lax.stop_gradient(state)\n",
+ "\n",
+ " def step(self, state: State, action: jax.Array) -> State:\n",
+ "\n",
+ " action = jp.clip(action, -1, 1)\n",
+ "\n",
+ " cur_base = state.obs[-12:]\n",
+ " action += cur_base\n",
+ " state.info['baseline_action'] = cur_base\n",
+ "\n",
+ " action = self.action_loc + (action * self.action_scale)\n",
+ "\n",
+ " data = self.pipeline_step(state.pipeline_state, action)\n",
+ " \n",
+ " # observation data\n",
+ " x, xd = data.x, data.xd\n",
+ " obs = self._get_obs(data.qpos, x, xd, state.info)\n",
+ "\n",
+ " # Terminate if flipped over or fallen down.\n",
+ " done = 0.0\n",
+ " done = jp.where(x.pos[0, 2] < self.termination_height, 1.0, done)\n",
+ " up = jp.array([0.0, 0.0, 1.0])\n",
+ " done = jp.where(jp.dot(math.rotate(up, x.rot[0]), up) < 0, 1.0, done)\n",
+ "\n",
+ " #### Foot Position Reference Updating ####\n",
+ "\n",
+ " # Detect the start of a new step\n",
+ " s = state.info['steps']\n",
+ " step_num = s // (self.step_k)\n",
+ " even_step = step_num % 2 == 0\n",
+ " new_step = (s % self.step_k) == 0\n",
+ " new_even_step = jp.logical_and(new_step, even_step)\n",
+ " new_odd_step = jp.logical_and(new_step, jp.logical_not(even_step))\n",
+ "\n",
+    "    # Apply the Raibert heuristic to calculate the target foot position at the end of the step\n",
+ " hip_xy = data.geom_xpos[self.hip_inds][:,:2] # 4 x 2\n",
+ " v_body = data.qvel[0:2]\n",
+ " step_period = self.gait_period/2\n",
+ " raibert_xy = hip_xy + (step_period/2) * v_body\n",
+ "\n",
+ " # Update. \n",
+ " cur_tars = state.info['xy*']\n",
+ " i_RFLH = jp.array([1, 2])\n",
+ " i_LFRH = jp.array([0, 3])\n",
+ " feet_xy = data.geom_xpos[self.feet_inds][:,:2]\n",
+ " \n",
+ " # With the trotting gait, we will move one pair of opposite legs, \n",
+ " # and keep the other pair fixed in place.\n",
+ " case_c1 = raibert_xy.at[i_LFRH].set(feet_xy[i_LFRH]) \n",
+ " case_c2 = raibert_xy.at[i_RFLH].set(feet_xy[i_RFLH])\n",
+ " xy_tars = jp.where(new_even_step, case_c1, cur_tars)\n",
+ " xy_tars = jp.where(new_odd_step, case_c2, xy_tars)\n",
+ " state.info['xy*'] = xy_tars\n",
+ "\n",
+ " # Save timestep and location at start of step.\n",
+ " state.info['k0'] = jp.where(new_step,\n",
+ " state.info['steps'],\n",
+ " state.info['k0'])\n",
+ " state.info['xy0'] = jp.where(new_step, \n",
+ " feet_xy,\n",
+ " state.info['xy0'])\n",
+ "\n",
+ " # reward\n",
+ " reward_tuple = {\n",
+ " 'tracking_lin_vel': (\n",
+ " self._reward_tracking_lin_vel(jp.array([self.target_vel, 0, 0]), x, xd)\n",
+ " * self.reward_config.rewards.scales.tracking_lin_vel\n",
+ " ),\n",
+ " 'orientation': (\n",
+ " self._reward_orientation(x)\n",
+ " * self.reward_config.rewards.scales.orientation\n",
+ " ),\n",
+ " 'lin_vel_z': (\n",
+ " self._reward_lin_vel_z(xd)\n",
+ " * self.reward_config.rewards.scales.lin_vel_z\n",
+ " ),\n",
+ " 'height': (\n",
+ " self._reward_height(data.qpos) \n",
+ " * self.reward_config.rewards.scales.height\n",
+ " ),\n",
+ " 'torque': (\n",
+ " self._reward_action(data.qfrc_actuator)\n",
+ " * self.reward_config.rewards.scales.torque\n",
+ " ),\n",
+ " 'joint_velocity': (\n",
+ " self._reward_joint_velocity(data.qvel)\n",
+ " * self.reward_config.rewards.scales.joint_velocity\n",
+ " ),\n",
+ " 'feet_pos': (\n",
+ " self._reward_feet_pos(data, state)\n",
+ " * self.reward_config.rewards.scales.feet_pos\n",
+ " ),\n",
+ " 'feet_height': (\n",
+ " self._reward_feet_height(data, state.info)\n",
+ " * self.reward_config.rewards.scales.feet_height\n",
+ " )\n",
+ " }\n",
+ " \n",
+ " reward = sum(reward_tuple.values())\n",
+ "\n",
+ " # state management\n",
+ " state.info['reward_tuple'] = reward_tuple\n",
+ " state.info['last_action'] = action\n",
+ "\n",
+ " for k in state.info['reward_tuple'].keys():\n",
+ " state.metrics[k] = state.info['reward_tuple'][k]\n",
+ "\n",
+ " # next action\n",
+ " action_key, key = jax.random.split(state.info['rng'])\n",
+ " state.info['rng'] = key\n",
+ " next_action, _ = self.baseline_inference_fn(obs, action_key)\n",
+ " obs = jp.concatenate([obs, next_action])\n",
+ "\n",
+ " state = state.replace(\n",
+ " pipeline_state=data, obs=obs, reward=reward,\n",
+ " done=done)\n",
+ " return state\n",
+ "\n",
+ " def _get_obs(self, qpos: jax.Array, x: Transform, xd: Motion,\n",
+ " state_info: Dict[str, Any]) -> jax.Array:\n",
+ "\n",
+ " inv_base_orientation = math.quat_inv(x.rot[0])\n",
+ " local_rpyrate = math.rotate(xd.ang[0], inv_base_orientation)\n",
+ "\n",
+ " obs_list = []\n",
+ " # yaw rate\n",
+ " obs_list.append(jp.array([local_rpyrate[2]]) * 0.25)\n",
+ " # projected gravity\n",
+ " obs_list.append(\n",
+ " math.rotate(jp.array([0.0, 0.0, -1.0]), inv_base_orientation))\n",
+ " # motor angles\n",
+ " angles = qpos[7:19]\n",
+ " obs_list.append(angles - self._default_ap_pose)\n",
+ " # last action\n",
+ " obs_list.append(state_info['last_action'])\n",
+ " # gait schedule\n",
+ " kin_ref = self.kinematic_ref_qpos[jp.array(state_info['steps']%self.l_cycle, int)]\n",
+ " obs_list.append(kin_ref)\n",
+ "\n",
+ " obs = jp.clip(jp.concatenate(obs_list), -100.0, 100.0)\n",
+ "\n",
+ " return obs\n",
+ "\n",
+ " # ------------ reward functions----------------\n",
+ " def _reward_tracking_lin_vel(\n",
+ " self, commands: jax.Array, x: Transform, xd: Motion) -> jax.Array:\n",
+ " # Tracking of linear velocity commands (xy axes)\n",
+ " local_vel = math.rotate(xd.vel[0], math.quat_inv(x.rot[0]))\n",
+ " lin_vel_error = jp.sum(jp.square(commands[:2] - local_vel[:2]))\n",
+ " lin_vel_reward = jp.exp(-lin_vel_error)\n",
+ " return lin_vel_reward\n",
+ " def _reward_orientation(self, x: Transform) -> jax.Array:\n",
+ " # Penalize non flat base orientation\n",
+ " up = jp.array([0.0, 0.0, 1.0])\n",
+ " rot_up = math.rotate(up, x.rot[0])\n",
+ " return jp.sum(jp.square(rot_up[:2]))\n",
+ " def _reward_lin_vel_z(self, xd: Motion) -> jax.Array:\n",
+ " # Penalize z axis base linear velocity\n",
+ " return jp.clip(jp.square(xd.vel[0, 2]), 0, 10)\n",
+ " def _reward_joint_velocity(self, qvel):\n",
+ " return jp.clip(jp.sqrt(jp.sum(jp.square(qvel[6:]))), 0, 100)\n",
+ " def _reward_height(self, qpos) -> jax.Array:\n",
+ " return jp.exp(-jp.abs(qpos[2] - self.target_h)) # Not going to be > 1 meter tall.\n",
+ " def _reward_action(self, action) -> jax.Array:\n",
+ " return jp.sqrt(jp.sum(jp.square(action)))\n",
+ " def _reward_feet_pos(self, data, state): \n",
+ " dt = (state.info['steps'] - state.info['k0']) * self.dt # scalar\n",
+ " step_period = self.gait_period / 2\n",
+ " xyt = state.info['xy0'] + (state.info['xy*'] - state.info['xy0']) * (dt/step_period)\n",
+ "\n",
+ " feet_pos = data.geom_xpos[self.feet_inds][:, :2]\n",
+ "\n",
+ " rews = jp.sum(jp.square(feet_pos - xyt), axis=1) \n",
+ " rews = jp.clip(rews, 0, 10)\n",
+ " return jp.sum(rews)\n",
+ " def _reward_feet_height(self, data, state_info):\n",
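+    "    \"\"\"Penalizes foot-height error against rectified sine-wave targets.\n",
+    "\n",
+    "    Each diagonal pair follows a sine of amplitude h_tar with period\n",
+    "    gait_period, clipped below at zero and raised by a 0.02 offset; the two\n",
+    "    pairs are half a period out of phase.\n",
+    "    \"\"\"\n",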
+ " h_tar = 0.1\n",
+ " t = state_info['steps'] * self.dt\n",
+ " offset = self.gait_period/2\n",
+ " ref1 = jp.sin((2*jp.pi/self.gait_period)*t) # RF and LH feet\n",
+ " ref2 = jp.sin((2*jp.pi/self.gait_period)*(t - offset)) # LF and RH\n",
+ " \n",
+ " ref1, ref2 = ref1 * h_tar, ref2 * h_tar\n",
+ " h_tars = jp.array([ref2, ref1, ref1, ref2])\n",
+ " h_tars = h_tars.clip(min=0, max=None) + 0.02 # offset height of feet.\n",
+ " \n",
+ " feet_height = data.geom_xpos[self.feet_inds][:,2]\n",
+ " errs = jp.clip(jp.square(feet_height - h_tars), 0, 10)\n",
+ " return jp.sum(errs)\n",
+ " \n",
+ "envs.register_environment('anymal', FwdTrotAnymal)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "#### Residual Learning via FoPGs\n",
+    "Takes 15 minutes on an NVIDIA 3060 Ti GPU"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Reconstruct the trotting inference function\n",
+ "make_networks_factory = functools.partial(\n",
+ " apg_networks.make_apg_networks,\n",
+ " hidden_layer_sizes=(256, 128)\n",
+ ")\n",
+ "\n",
+ "nets = make_networks_factory(observation_size=1, # Observation_size argument doesn't matter since it's only used for param init.\n",
+ " action_size=12,\n",
+ " preprocess_observations_fn=running_statistics.normalize)\n",
+ "\n",
+ "make_inference_fn = apg_networks.make_inference_fn(nets)\n",
+ "\n",
+ "# Configure locomotion training\n",
+ "make_networks_factory = functools.partial(\n",
+ " apg_networks.make_apg_networks,\n",
+ " hidden_layer_sizes=(128, 64)\n",
+ ")\n",
+ "\n",
+ "epochs = 499\n",
+ "\n",
+ "train_fn = functools.partial(apg.train,\n",
+ " episode_length=1000,\n",
+ " policy_updates=epochs,\n",
+ " horizon_length=32,\n",
+ " num_envs=64,\n",
+ " learning_rate=1.5e-4,\n",
+ " schedule_decay=0.995,\n",
+ " num_eval_envs=64,\n",
+ " num_evals=10 + 1,\n",
+ " use_float64=True,\n",
+ " normalize_observations=True,\n",
+ " network_factory=make_networks_factory)\n",
+ "\n",
+ "model_path = '/tmp/trotting_2hz_policy'\n",
+ "params = model.load_params(model_path)\n",
+ "baseline_inference_fn = make_inference_fn(params)\n",
+ "\n",
+ "env_kwargs = dict(target_vel=0.75, step_k=13, \n",
+ " baseline_inference_fn=baseline_inference_fn)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjMAAAGdCAYAAADnrPLBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA/OElEQVR4nO3deXhU5f3//1fWyUL2QEIggaBBlrCEBC2gAmVTAWu1AqIov1orBZcYFEGtIl9JBBVUaFGsn2qxCFpFbZVKtBakuEBIWAVElgSSEMCQScgySeb8/oiMhkW2TM7M5Pm4rrnMnLknec9UZ169z/vct5dhGIYAAADclLfZBQAAAFwMwgwAAHBrhBkAAODWCDMAAMCtEWYAAIBbI8wAAAC3RpgBAABujTADAADcmq/ZBTQHu92uwsJChYSEyMvLy+xyAADAOTAMQ+Xl5YqLi5O395nnX1pEmCksLFR8fLzZZQAAgAtQUFCg9u3bn/HxFhFmQkJCJDW8GaGhoSZXAwAAzoXValV8fLzje/xMWkSYOXFqKTQ0lDADAICbOVuLCA3AAADArRFmAACAWyPMAAAAt0aYAQAAbo0wAwAA3BphBgAAuDWnhpm6ujo99thjSkxMVGBgoDp16qRZs2bJbrc7xhiGoZkzZyouLk6BgYEaNGiQtm3b1uj31NTU6N5771V0dLSCg4N1/fXX68CBA84sHQAAuAmnhpk5c+bopZde0sKFC/XNN99o7ty5euaZZ7RgwQLHmLlz52revHlauHCh1q9fr9jYWA0bNkzl5eWOMenp6VqxYoWWLVumtWvXqqKiQqNGjVJ9fb0zywcAAG7AyzAMw1m/fNSoUYqJidGrr77qOHbTTTcpKChIS5YskWEYiouLU3p6uh5++GFJDbMwMTExmjNnju6++26VlZWpdevWWrJkicaOHSvpx+0JPvroI40YMeKsdVitVoWFhamsrIxF8wAAcBPn+v3t1JmZK6+8Up9++ql27dolSdq0aZPWrl2r6667TpK0d+9eFRcXa/jw4Y7nWCwWDRw4UOvWrZMk5eTkqLa2ttGYuLg4JScnO8acrKamRlartdENAAB4JqduZ/Dwww+rrKxMXbp0kY+Pj+rr6zV79mzdcsstkqTi4mJJUkxMTKPnxcTEaP/+/Y4x/v7+ioiIOGXMieefLCsrS08++WRTvxwAAOCCnDozs3z5cr3xxhtaunSpNm7cqNdff13PPvusXn/99UbjTt5zwTCMs+7D8HNjZsyYobKyMsetoKDg4l4IAABwWU6dmXnooYc0ffp0jRs3TpLUo0cP7d+/X1lZWbrjjjsUGxsrqWH2pW3bto7nlZSUOGZrYmNjZbPZVFpa2mh2pqSkRP379z/t37VYLLJYLM56WQAAwIU4dWamsrJS3t6N/4SPj4/j0uzExETFxsYqOzvb8bjNZtPq1asdQSU1NVV+fn6NxhQVFWnr1q1nDDMAAMD5Km116jj9Q3Wc/qEqbXWm1eHUmZnRo0dr9uzZSkhIUPfu3ZWbm6t58+bpt7/9raSG00vp6enKzMxUUlKSkpKSlJmZqaCgII0fP16SFBYWpjvvvFNTp05VVFSUIiMj9eCDD6pHjx4aOnSoM8sHAABuwKlhZsGCBfrjH/+oyZMnq6SkRHFxcbr77rv1+OOPO8ZMmzZNVVVVmjx5skpLS3XFFVdo1apVCgkJcYyZP3++fH19NWbMGFVVVWnIkCF67bXX5OPj48zyAQCAG3DqOjOugnVmAABoepW2OnV7/GNJ0vZZIxTk37RzJC6xzgwAAGZwlV4ONA+nnmYCAADnxtmzHBei3m7oWKVNR4/bdKS8RkeO23S0okZHK2w6erxGxdZqx9h3Nh7QhF90NKVO898pAADQbKps9TpSUaMjPwklRypsje4f/eH+98dtsp9jM0pxWfXZBzkJYQYAADd2YvbkSEXDrMmJ2ZMT4eSII7A03K+0nf8mzRFBfopqZVFU
sL+iQyyKDvZXVCuLQgJ89eQ/t0uSxvaNb+qXds4IMwAAuJhKW51jdsTxz+ONZ0+OlDf883xmT07w9/VW61YWRbfyd4SUqB/uR7eyKKqVv6KCLYoO8VdEkL/8fE7fYltpq3OEmehW5i1WS5gBAMBEdruh7UVWffLNIcextKc+Pe/fExHk92MQadUwe9Jwv+HYj0HFomB/n7NuG+ROCDMAADSzEmu11nx7RJ9/e1hrvz2io8dtp4yx+Hor+qTZk+iQH/7pCCgN/4wM8pfvGWZPWgLCDAAATlZdW68N+0q15tvDWrPrsHYUlzd6PNjfR5d3itRnOw5LktY/OkTRrSweNXviTIQZAACamGEY2l1SodW7Duvzb4/oq71HVV1rdzzu5SX1aBemq5KidXVSa6UkRKjObndcmh1s8XWLIBPk76t9T480uwzCDAAATaH0uE3/++6I1vwQYIpOulQ5JtSiq5Ja6+rOrXXlpdGKDPZv9HidzS5cGMIMAAAXoLberryCY1qz67DWfHtEmw8c0083CLL4euvyxEgN7NxaVyW1VueYVm4x2+KOCDMAAJyj/KOVWv3tYX2+67C++O6oymsab5VwWUxIw6mjzq11eWKkAvzYELk5EGYAADiDipo6ffHd0R9mXw5r/9HKRo9HBPnpyqTWujopWlcltVZsWMAF/y1X6T9xR4QZAAB+YLcb2lpY5jh1tHF/qep+siKdr7eX+nSI+OHUUbSS48Lk7c2pI7MRZgAALVpxWbU+/7YhvKz99rBKK2sbPd4xKsjRuNvvkii1svDV6Wr4XwQA0KJU19br673fO6462nmo8ZovrSy+6n9JlK7u3FpXJ7VWQlSQSZXiXBFmAAAezTAM7Swu1+ffHtbqXYf19d7vVVPXeM2Xnu3DNTApWld1bq3e8eFn3IsIrokwAwDwKIZhaN+R4477g59drZLymkZj2oYFOK46GnBJtCJOWvMF7oUwAwBwa+XVtdpUUKbc/FJtzC9VbsExHftJ30tJeY0C/Lx1ReKJU0fRurQNa754EsIMAMBt2O2G9hw53hBa8kuVm39MOw+VN1qsTmpYsO7EqaS/3JGmKy+NZs0XD0aYAQD8rEpbnWPPoO2zRijIv/m+OqzVtdpUcEwb9x/TxvxS5RUcU1lV7Snj2kcEqk9ChPokhCslIUIdooLUe1a2JKn/JVEEGQ9HmAEAuAS73dB3hyt+mHVpCC/fllScMusS4Oetnu3DlZIQrj4JEUpJCFebkMaL1VXaGq/MC89GmAEAmKKsslZ5B45p4/5Sx6xLefWpISQhMkh9EsLVp0OEUuIj1KVtCFcboRHCDADA6ex2Q9+WVDh6XTbmH9PukopTxgX6+ahn+zD16RChPgkR6h0frtYhFhMqhjshzAAAmtyxSptyC44pd39DcNlUcOyUTRmlhtV1T5wqSkmIUJfYEPky64LzRJgBAFyUeruhXYfKHX0uG/NLtefw8VPGBfn7qFf7cPXpEO6YdYlqxawLLh5hBgBwXkqP25RbUKqN+48pt6BUmwrKVHGaWZdO0cHq/UOTbp+ECHWOacWsC5yCMAMAOKPaeru2F1od96994XPtP1p5yrhgfx9HcElJCFdKfASr6qLZEGYAAJIatgE4UFqlvIJjjtvWg2WN9jE6EWQ6tQ52zLj06RCupDYh8vFmRV2YgzADAC1UWVWtNh84prz8huCy6cAxHamwnTIuNMBX1h8umX7ptj76RacohQcx6wLXQZgBgBagtt6uHUXlynOEl1J9d5omXT8fL3VtG6re8eGOW5sQi5JnrpIkXd25dbOuAAycC/6NBIBm1BxbA5zL6aITEiKD1OsnwaV7XOgpS/+zmi5cHWEGANzc+Zwu6hUfrpT4cPVOCFev9lwaDc9AmAEAN3Ixp4s6RgXLmyZdeCDCDAC4qPM9XdQ7Ptxxyuh0p4sAT0WYAQAXweki4MIQZgDAJNsKy7SjqFy5BQ17F53r6aLE6GB5eXG6CDjByzAM
w+winM1qtSosLExlZWUKDQ01uxwALVRNXb3+talQU9/efMYxnC4CfnSu39/MzACAk31TZNXy9QV6L++gjlXWOo5zughoGoQZAHACa3WtPsgr1FsbCrT5QJnjeEyoRYesNZKkL2b8UsEWP7NKBDwGYQYAmohhGPpq7/d6a32BPtpapOrahquO/Hy8NLRrjMb0jVdahwj1+GE1XfpegKZBmAGAi3TIWq1/5BzQ2xsKtO8nO0ontWmlsX3j9euUdo7TR6ymCzQ9wgwAXIDaerv+s6NEb60v0Gc7S2T/4VKKYH8fje4VpzF945USH87sC9AMvJ39Bw4ePKjbbrtNUVFRCgoKUu/evZWTk+N43DAMzZw5U3FxcQoMDNSgQYO0bdu2Rr+jpqZG9957r6KjoxUcHKzrr79eBw4ccHbpAHCK3SUVyvzoG/XL+lR3L8nRpzsagkzfjhF65jc99fWjQ/X0TT3VJyGCIAM0E6fOzJSWlmrAgAEaPHiwVq5cqTZt2ui7775TeHi4Y8zcuXM1b948vfbaa+rcubOeeuopDRs2TDt37lRISIgkKT09Xf/85z+1bNkyRUVFaerUqRo1apRycnLk48MliwCc63hNnT7cXKTlGwqUs7/UcTy6lUU3pbbTmLR4XdK6lYkVAi2bU9eZmT59uv73v//p888/P+3jhmEoLi5O6enpevjhhyU1zMLExMRozpw5uvvuu1VWVqbWrVtryZIlGjt2rCSpsLBQ8fHx+uijjzRixIiz1sE6MwDOl2EY2ph/TG+tL9C/NhfquK1ekuTj7aXBl7XWmLR4De7SRn4+Tp/gBlosl1hn5oMPPtCIESN08803a/Xq1WrXrp0mT56su+66S5K0d+9eFRcXa/jw4Y7nWCwWDRw4UOvWrdPdd9+tnJwc1dbWNhoTFxen5ORkrVu37rRhpqamRjU1NY77VqvVia8SgFkqbXXq9vjHkqTts0YoyP/iP9KOVNRoxcaDWr6hQLtLKhzHE6ODdXNae93Up71iQgMu+u8AaDpODTN79uzRokWLlJGRoUceeURff/217rvvPlksFt1+++0qLi6WJMXExDR6XkxMjPbv3y9JKi4ulr+/vyIiIk4Zc+L5J8vKytKTTz7phFcEwBPV2w2t2XVYy9cX6JNvDqnuh27eAD9vXdejrcamxevyxEh6YAAX5dQwY7fblZaWpszMTElSSkqKtm3bpkWLFun22293jDv5A8IwjLN+aPzcmBkzZigjI8Nx32q1Kj4+/kJfBgAPlX+0Um9tKNA/cg6o2FrtON4rPlxj0tprdK84hQawqB3g6pwaZtq2batu3bo1Ota1a1e98847kqTY2FhJDbMvbdu2dYwpKSlxzNbExsbKZrOptLS00exMSUmJ+vfvf9q/a7FYZLGwJDiAU1XX1mvl1iItX1+gL/d87zgeEeSnX6e015i+7dUllt46wJ04NcwMGDBAO3fubHRs165d6tChgyQpMTFRsbGxys7OVkpKiiTJZrNp9erVmjNnjiQpNTVVfn5+ys7O1pgxYyRJRUVF2rp1q+bOnevM8gF4CMMwtPWgVcs35Ov9vEKVVzcsXOflJV2V1Fpj0+I1tFsbWXy5OhJwR04NMw888ID69++vzMxMjRkzRl9//bUWL16sxYsXS2o4vZSenq7MzEwlJSUpKSlJmZmZCgoK0vjx4yVJYWFhuvPOOzV16lRFRUUpMjJSDz74oHr06KGhQ4c6s3wAbu5YpU3v5R7U8g0H9E3RjxcCtAsP1Ji0eP0mrb3ahQeaWCGApuDUMNO3b1+tWLFCM2bM0KxZs5SYmKjnn39et956q2PMtGnTVFVVpcmTJ6u0tFRXXHGFVq1a5VhjRpLmz58vX19fjRkzRlVVVRoyZIhee+011pgBcAq73dC6745q+YYCfbytWLa6hv2R/H28NSI5VmPT4tX/kih5e9PMC3gKp64z4ypYZwbwTD+9NPuTjKv14eZivZ1ToAOlVY4xXduGamxa
e92Q0k7hQf5mlQrgArjEOjMA0FyGzV+jE//XLCTAV7/qHaexaQlKbhfKJdWAhyPMAHBLZVW1enTFFsd9w5D6dYrS2L7xuiY5VgF+nIYGWgrCDAC3s273EU19e5OKyn5cG+bf6VdxSTXQQhFmALiN6tp6PfPxTr26dq8kKSEySPnfVzp+BtAysUMaALewrbBM1y9c6wgy469I0Dt/6GdyVQBcATMzAFxavd3Q4jV7NC97p2rrDUW3smjub3rol11iVGmrM7s8AC6AMAPAZRV8X6mpb23S1/sath0Y3i1GWTf2UFQrtisB8CPCDACXYxiG/pFzQE/+c7sqauoU7O+jJ67vrptT23OZNYBTEGYAuJSjFTV6ZMUWfbztkCQprUOE5o/trXgafAGcAWEGgMv4bEeJHvrHZh2pqJGfj5ceGNZZd199iXzYegDAzyDMADBdpa1Osz/8Rn//Kl+SlNSmleaP7a3kdmEmVwbAHRBmAJgqN79UGW9t0t4jxyVJvx2QqGnXXHZOK/gG+ftq39MjnV0iABdHmAFgitp6uxb+Z7cWfrZb9XZDbcMC9OzNvTTg0mizSwPgZggzAJrdd4crlLE8T5sOlEmSftU7TrOuT1ZYkJ/JlQFwR4QZAM3GMAy98eV+zf7oG1XX2hUa4Kunft1D1/eKM7s0AG6MMAOgWZRYq/XQPzZr9a7DkqQBl0bp2Zt7qW1YoMmVAXB3hBkATrdyS5EeWbFFpZW18vf11vRrumhi/47y5pJrAE2AMAPAaazVtZr5wTa9u/GgJKl7XKieH9tbSTEhJlcGwJMQZgA4xVd7jirjrU06eKxK3l7SHwZdovuHdJa/r7fZpQHwMIQZAE2qpq5e81bt0uLP98gwpPjIQM0f01tpHSPNLg2AhyLMAGgyO4qtSl+Wpx3F5ZKksWnx+uPobmpl4aMGgPPwCQPgotnthl5du1fPfLxTtnq7IoP9lXVjD43oHmt2aQBaAMIMAEkN+yN1e/xjSdL2WSMU5H9uHw8Hj1Vp6lt5+nLP95KkIV3a6Ombeqp1iMVptQLATxFmAFwQwzD0Xt5BPf7eNpXX1CnI30d/HNVN4/rGy8uLS64BNB/CDIDzVnrcpsfe26oPtxRJklISwjV/TG91jA42uTIALRFhBsB5Wb3rsB56e5NKymvk6+2l+4ck6Q+DLpGvD5dcAzAHYQbAOamy1evpld/o9S/2S5I6tQ7W82N7q2f7cHMLA9DiEWYAnNXmA8eUvjxPew4flyTd0a+Dpl/bVYH+PiZXBgCEGQA/o67erkX//U4vfPqt6uyG2oRY9MzNvTSwc2uzSwMAB8IMgNPad+S4HngrT7n5xyRJI3u01VM3JCsi2N/cwgDgJIQZAKd4a0OB5qzcqaraeoVYfDXrhu66oXc7LrkG4JIIMwBOMfOD7ZKkX3SK1HNjeqtdeKDJFQHAmRFmAEiSDpRWOn728/HStBFddOeVifL2ZjYGgGsjzABQla1e976Z57j/9qR+6h0fYV5BAHAeWOUKaOEMw9D0dzdr5w87XUtS55gQEysCgPNDmAFauFfX7tX7eYXy4XQSADdFmAFasHXfHVHWyh2SpGnXXGZyNQBwYQgzQAt18FiV7lmaq3q7oRtT2um2KxLMLgkALghhBmiBqmvrNWlJjr4/blP3uFBl3tiDNWQAuC3CDNDCGIahR1Zs0ZaDZYoI8tPLE1IV4MceSwDcF2EGaGH+9sV+vbvxoLy9pD+N76P2EUFmlwQAF4UwA7QgX+05qv/3r4bVfR+5rqv6XxptckUAcPEIM0ALUVRWpSlLN6rObmh0rzjdeWWi2SUBQJMgzAAtQHVtvSa9sVFHKmzqEhuiOTfR8AvAczRbmMnKypKXl5fS09MdxwzD0MyZMxUXF6fAwEANGjRI27Zta/S8mpoa3XvvvYqOjlZwcLCuv/56HThwoLnKBtyeYRh64v1t2lRwTGGBflo8IU1B/qfuZBLk76t9T4/UvqdHnvZxAHBV
zRJm1q9fr8WLF6tnz56Njs+dO1fz5s3TwoULtX79esXGxmrYsGEqL/9xWfX09HStWLFCy5Yt09q1a1VRUaFRo0apvr6+OUoH3N7Sr/O1fEOBvL2kBbekKCGKhl8AnsXpYaaiokK33nqrXnnlFUVE/LhxnWEYev755/Xoo4/qxhtvVHJysl5//XVVVlZq6dKlkqSysjK9+uqreu655zR06FClpKTojTfe0JYtW/TJJ584u3TA7eXs/14zP2iY7XxoRBdd3bm1yRUBQNNzepiZMmWKRo4cqaFDhzY6vnfvXhUXF2v48OGOYxaLRQMHDtS6deskSTk5OaqtrW00Ji4uTsnJyY4xp1NTUyOr1droBrQ0h6zVmvTGRtXWG7quR6wmDexkdkkA4BROPTG+bNkybdy4UevXrz/lseLiYklSTExMo+MxMTHav3+/Y4y/v3+jGZ0TY048/3SysrL05JNPXmz5gNuy1dn1hzdydLi8Rp1jWumZ3/Si4ReAx3LazExBQYHuv/9+vfHGGwoICDjjuJM/YA3DOOuH7tnGzJgxQ2VlZY5bQUHB+RUPuLkn/7lNG/OPKTTAV4snpCnYQkMvAM/ltDCTk5OjkpISpaamytfXV76+vlq9erVefPFF+fr6OmZkTp5hKSkpcTwWGxsrm82m0tLSM445HYvFotDQ0EY3oKVY9nW+/v5Vvry8pBfGpahjdLDZJQGAUzktzAwZMkRbtmxRXl6e45aWlqZbb71VeXl56tSpk2JjY5Wdne14js1m0+rVq9W/f39JUmpqqvz8/BqNKSoq0tatWx1jAPwoN79Uj7/f0PCbMbSzBndpY3JFAOB8Tpt7DgkJUXJycqNjwcHBioqKchxPT09XZmamkpKSlJSUpMzMTAUFBWn8+PGSpLCwMN15552aOnWqoqKiFBkZqQcffFA9evQ4paEYaOkOl9foD29slK3eruHdYjRl8KVmlwQAzcLUE+nTpk1TVVWVJk+erNLSUl1xxRVatWqVQkJCHGPmz58vX19fjRkzRlVVVRoyZIhee+01+fiwyy9wQm29XVP+vlHF1mpd0jpYz43pJW9vGn4BtAxehmEYZhfhbFarVWFhYSorK6N/Bh5p5gfb9Nq6fWpl8dX79wzQJa1bmV0SAFy0c/3+Zm8mwM39I+eAXlu3T5I0f2xvggyAFocwA7ixLQfK9MiKLZKk+4ckaVi3M1/lBwCeijADuKmjFTW6e8kG2ersGtKlje4fkmR2SQBgCsIM4Ibq6u2asnSjCsuqlRgdrHlje9PwC6DFIswAbihr5Q59ued7Bfv7aPGEVIUF+pldEgCYhjADuJn38w7q1bV7JUnPjemlpJiQszwDADwbYQZwI9sKy/TwO5slSVMGX6JrktuaXBEAmI8wA7iJ0uM23b0kR9W1dg3s3FoZwy4zuyQAcAmEGcAN1NXbde+buTpQWqWEyCC9OC5FPjT8AoAkwgzgFp5ZtVNrdx9RoJ+PFt+eqrAgGn4B4ATCDODi/rW5UC+v3iNJeubmnuoSy5YcAPBThBnAhe0otuqhtxsafu++upNG9YwzuSIAcD2EGcBFHau06fd/y1FVbb2uvDRaD42g4RcATocwA7igeruh+5flKf/7SrWPCNSCW1Lk68N/rgBwOnw6Ai5ofvYurd51WAF+3np5Qqoigv3NLgkAXBZhBnAx/95apIWf7ZYkzbmpp7rHhZlcEQC4NsIM4ASVtjp1nP6hOk7/UJW2unN+3reHyjX1rU2SpDuvTNSverdzVokA4DEIM4CLsFbX6vdLcnTcVq9fdIrUjGu7mF0SALgFwgzgAux2Qw8sy9PeI8cVFxagheP70PALAOeIT0vABbzw6bf6dEeJ/H299dKEVEW3sphdEgC4DcIMYLLs7Yf0wqffSpJm35Csnu3DzS0IANwMYQYw0XeHK/TA8jxJ0h39OujmtHhzCwIAN0SYAUxSXl2r3/9tgypq6nR5x0g9Nqqb2SUBgFsizAAmsNsN
TX1rk747fFyxoQH606195EfDLwBcED49ARP8+b+7tWr7Ifn7eGvRbX3UOoSGXwC4UIQZoJl9tqNEz2XvkiTN+lV3pSREmFwRALg3wgzQjPYdOa77luXKMKTxVyRo3OUJZpcEAG6PMAM0k+M1dfr9kg0qr65Tn4RwPTGahl8AaAqEGaAZGIahh/6xSbsOVah1iEWLbkuVxdfH7LIAwCMQZoBm8PKaPfpoS7H8fLy06NY+igkNMLskAPAYhBnAyf63+4jm/nuHJOmJ0d2V1jHS5IoAwLMQZgAnm/r2JtkNaWxavG69goZfAGhqhBnAyaxVderVPkxP/qq7vLy8zC4HADwOYQZwAsMwHD9HBftr0W2pCvCj4RcAnIEwAzjB//1vn+Pn+WN7KS480LxiAMDDEWaAJvbZjhLN+2GFX0k0/AKAkxFmgCa0u6Rc973ZsMIvAKB5EGaAJnKs0qbfvb5B5TV1Su3AfksA0FwIM0ATqKu3656ludp3tFLtwgP1wrjeZpcEAC0GYQZoArM/+kZrdx9RkL+PXrk9TZHB/maXBAAtBmEGuEjL1+frrz9cvTRvTC91iws1tyAAaGEIM8BFWL/vez323lZJ0gNDO+ua5LYmVwQALQ9hBrhAB49VadKSHNXWG7quR6zu/eWlZpcEAC2SU8NMVlaW+vbtq5CQELVp00Y33HCDdu7c2WiMYRiaOXOm4uLiFBgYqEGDBmnbtm2NxtTU1Ojee+9VdHS0goODdf311+vAgQPOLB34WZW2Ot31+gYdPW5T17ahevbmXvL2ZqsCADCDU8PM6tWrNWXKFH355ZfKzs5WXV2dhg8fruPHjzvGzJ07V/PmzdPChQu1fv16xcbGatiwYSovL3eMSU9P14oVK7Rs2TKtXbtWFRUVGjVqlOrr651ZPnBahmHowbc3aXuRVVHB/nrl9lQF+fuaXRYAtFhehtF8y3sdPnxYbdq00erVq3X11VfLMAzFxcUpPT1dDz/8sKSGWZiYmBjNmTNHd999t8rKytS6dWstWbJEY8eOlSQVFhYqPj5eH330kUaMGHHWv2u1WhUWFqaysjKFhtKciYvzwiffav4nu+Tn46Wld/1CfU+zwm+lrU7dHv9YkrR91gjCDgBcgHP9/m7WnpmysjJJUmRkw4f/3r17VVxcrOHDhzvGWCwWDRw4UOvWrZMk5eTkqLa2ttGYuLg4JScnO8acrKamRlartdENaAr/3lqk+Z80bFXw1A3Jpw0yAIDm1WxhxjAMZWRk6Morr1RycrIkqbi4WJIUExPTaGxMTIzjseLiYvn7+ysiIuKMY06WlZWlsLAwxy0+Pr6pXw5aoO2FVj2wfJMkaWL/jhrbN8HkigAAUjOGmXvuuUebN2/Wm2++ecpjXl6NGycNwzjl2Ml+bsyMGTNUVlbmuBUUFFx44YCkoxU1uutvG1RVW68rL43WYyO7ml0SAOAHzRJm7r33Xn3wwQf67LPP1L59e8fx2NhYSTplhqWkpMQxWxMbGyubzabS0tIzjjmZxWJRaGhooxtwoWx1dv3h7xt18FiVOkYFaeH4FPn6sKoBALgKp34iG4ahe+65R++++67+85//KDExsdHjiYmJio2NVXZ2tuOYzWbT6tWr1b9/f0lSamqq/Pz8Go0pKirS1q1bHWMAZzEMQ098sE1f7/1erSy++ssdaQoPYqsCAHAlTr3EYsqUKVq6dKnef/99hYSEOGZgwsLCFBgYKC8vL6WnpyszM1NJSUlKSkpSZmamgoKCNH78eMfYO++8U1OnTlVUVJQiIyP14IMPqkePHho6dKgzywe05Mv9evPrfHl5SS/e0luXtgkxuyQAwEmcGmYWLVokSRo0aFCj43/96181ceJESdK0adNUVVWlyZMnq7S0VFdccYVWrVqlkJAfvzTmz58vX19fjRkzRlVVVRoyZIhee+01+fj4OLN8tHDrdh/Rk//cLkl6+Jou+mWX05/WBACYq1nXmTEL68zgfO0/ely/
+tP/dKyyVr9Oaad5Y3qdtSkdANC0XHKdGcAdlFfX6nevb9Cxylr1ig9X1o09CDIA4MIIM8BP1NsNpS/L07clFWoTYtHiCakK8ON0JgC4MsIM8BPPrdqpT3eUyN/XW4tvT1NMaIDZJQEAzoIwA/zg/byD+vN/v5Mkzb2pp3rHh5tbEADgnBBmAEmbDxzTtH9sliRNGniJbkhpZ3JFAIBzRZhBi1dirdbv/5ajmjq7ftmljR4acZnZJQEAzgNhBi1adW29fr8kR8XWal3appVeGNdbPt5cuQQA7oQwgxbLMAw98u4W5RUcU1ign/5ye5pCAvzMLgsAcJ4IM2ix/vL5Xr2be1A+3l7686191DE62OySAAAXgDCDFumznSXKWvmNJOmPI7tqwKXRJlcEALhQhBm4vEpbnTpO/1Adp3+oSlvdRf++3SUVum9pruyGNK5vvO7o3/HiiwQAmIYwgxalrLJWd/1tg8pr6tS3Y4Rm/SqZrQoAwM0RZtBi1NXbdc+bG7X3yHG1Cw/UottS5e/LfwIA4O74JEeLkbVyhz7/9ogC/Xy0+PZURbeymF0SAKAJEGbQIry1vkCvrt0rSZo3ppe6x4WZXBEAoKkQZuDxNuz7Xo++t0WSdP+QJF3bo63JFQEAmhJhBh7t4LEqTXojR7X1hq5NjtX9Q5LMLgkA0MQIM/BYVbZ6/f5vG3SkwqaubUP13Jhe8marAgDwOIQZeCTDMPTgPzZpW6FVUcH+euX2VAX5+5pdFgDACQgz8EgL/7NbH24ukp+Plxbdlqr2EUFmlwQAcBLCDDzOv7cW67nsXZKk//erZF2eGGlyRQAAZyLMwKPsKLYq4608SdLE/h017vIEcwsCADgdYQYe4/vjNv3u9Q2qtNVrwKVRemxkV7NLAgA0A8IMPEJtvV1/eCNHB0qr1CEqSH8a30e+PvzrDQAtAZ/28AgzP9imr/Z+r1YWX/3l9jSFB/mbXRIAoJkQZuD2lny5X3//Kl9eXtIL43orKSbE7JIAAM2IMAO3tu67I3ryg22SpGkjumhI1xiTKwIANDfCDNxW/tFKTfn7RtXZDf2qd5wmDexkdkkAABMQZuCWKmrqdNffNqi0sla92odpzk095eXFVgUA0BIRZuB27HZD6cvytPNQudqEWPTyhDQF+PmYXRYAwCSEGbidedm79Mk3h+Tv662XJ6QqNizA7JIAACYizMCtrNxSpIWf7ZYkPX1jD6UkRJhcEQDAbIQZuJVH39sqSbr76k66sU97k6sBALgCwgzcSnWtXYMva61p13QxuxQAgIsgzMDl2ersjp87RQfrhVtS5OPNlUsAgAaEGbi851btdPy88NYUhQb4mVgNAMDVEGbg0j795pCWfJnvuN8xKtjEagAArogwA5d1yFqtB9/eZHYZAAAXR5iBS6r/YWG80spadW3LxpEAgDMjzMAlvbT6O32x56iC/H303M29zC4HAODCCDNwORvzSzUve5ck6cnru6tjNH0yAIAzI8zApVira3Xfm7mqtxu6vlecfpPKwngAgJ9HmIHLMAxDj7y7RQdKqxQfGainfp3MTtgAgLNymzDz5z//WYmJiQoICFBqaqo+//xzs0tCE3s754D+tblIvt5eenEc68kAAM6NW4SZ5cuXKz09XY8++qhyc3N11VVX6dprr1V+fv7Znwy38N3hCj3x/jZJUsbwzmwgCQA4Z24RZubNm6c777xTv/vd79S1a1c9//zzio+P16JFi8wuDU2gpq5e9y7NVVVtvQZcGqVJV19idkkAADfi8mHGZrMpJydHw4cPb3R8+PDhWrdu3WmfU1NTI6vV2ugG1zVn5U5tL7IqMthf88b0ljf7LgEAzoPLh5kjR46ovr5eMTExjY7HxMSouLj4tM/JyspSWFiY4xYfH98cpeIC/GfHIf3f//ZKkp69uadiQgNMrggA4G5cPsyccPJVLYZhnPFKlxkzZqisrMxxKygoaI4ScZ5KrNV68O3NkqT/b0BH
/bJLzFmeAQDAqXzNLuBsoqOj5ePjc8osTElJySmzNSdYLBZZLJbmKA8XyG439MBbefr+uE3d2oZq+rVdzC4JAOCmXH5mxt/fX6mpqcrOzm50PDs7W/379zepKlysl9fs0f92H1Wgn49evCVFFl8fs0sCALgpl5+ZkaSMjAxNmDBBaWlp6tevnxYvXqz8/HxNmjTJ7NJwAXLzS/Xcqp2SGrYruLRNK5MrAgC4M7cIM2PHjtXRo0c1a9YsFRUVKTk5WR999JE6dOhgdmk4T9bqWt23LFd1dkOjerbVzWlsVwAAuDhuEWYkafLkyZo8ebLZZeAiGIahx1ZsVcH3VWofEajZv+7BdgUAgIvm8j0z8BzvbDyoDzYVysfbSy+MS1FYINsVAAAuHmEGzWLP4Qo9/v5WSVLGsM5K7cB2BQCApkGYgdPV1NXr3jdzVWmrV79OUZo0kO0KAABNhzADp5v7753aVmhVRJCf5o/tLR+2KwAANCHCDJzqsx0lenVtw3YFz/yml2LD2K4AANC0CDNwmobtCjZJkib276ih3diuAADQ9AgzcAq73VDGW5t09LhNXWJD2K4AAOA0brPODNzL4s/3aO3uIwr089HC8SkK8Lvw7QqC/H217+mRTVgdAMCTMDODJpdXcEzPftywXcHM67vp0jYhJlcEAPBkhBk0qfLqWt33ZsN2BSN7ttWYtHizSwIAeDjCDJqMYRh67L2tyv++Uu3CA5XJdgUAgGZAmEGTeXfjQb2f17BdwYu39Ga7AgBAsyDMoEnsOVyhP/6wXcEDQ5OU2iHS5IoAAC0FYQYXzVZn133LGrYr+EWnSP1h0KVmlwQAaEEIM7hoz3y8Q1sPWhUe5Kfnx6awXQEAoFkRZnBR/ruzRK98znYFAADzEGZwwUrKf9yu4I5+HTSM7QoAACYgzOCC2O2Gpr61SUcqGrYrmHFdV7NLAgC0UIQZXJC/rN2jz789ogA/by245eK2KwAA4GIQZnDeNhUc09x/N2xX8MTo7kqKYbsCAIB5CDM4L+XVtbpvWcN2Bdf1iNW4vmxXAAAwF2EG5+Xx97dp/9GG7Qqyft2T7QoAAKYjzOCcvbvxgFbkHpS3l/TCuN4KC2K7AgCA+QgzLUylrU4dp3+ojtM/VKWt7pyft/fIcf3xvYbtCtKHdlZaR7YrAAC4BsIMzspWZ9d9b+bquK1elydGaspgtisAALgOwgzO6tlVO7XlYJnCg/z0wrjebFcAAHAphBn8rNW7Dmvxmj2SpDk39VTbsECTKwIAoDHCDM7ocHmNpr6VJ0ma8IsOGtE91tyCAAA4DcIMTstuNzT17YbtCi6LCdGjI9muAADgmggzOK1X1+7Vml2HZfH11oLxbFcAAHBdhBmcYvOBY5r78Q5J0uOju6kz2xUAAFwYYQaNVNTU6b43c1Vbb+ia7rEaf3mC2SUBAPCzCDNo5PH3t2rf0UrFhQXo6Zt6sF0BAMDlEWbgsCL3gN7d2LBdwfPjUhQe5G92SQAAnBVhBpKk/UeP67EVDdsV3D+ksy5PZLsCAIB7IMzglO0K7vkl2xUAANwHYQZ6LnunNh0oU1ign54fy3YFAAD3Qphp4dbsOqyXV/+4XUFcONsVAADcC2GmBTtaUaOMtzZJkm77RYKuSWa7AgCA+yHMtGAz3t2qIxU16hzTSo+N7GZ2OQAAXBDCTAu2dveRhu0KbunDdgUAALdFmGnh/jiqmy6LZbsCAID7Isy0MBXVdY6fh3Zto1uvYLsCAIB7I8y0IAePVWnC/33tuD/rV93ZrgAA4PacFmb27dunO++8U4mJiQoMDNQll1yiJ554QjabrdG4/Px8jR49WsHBwYqOjtZ99913ypgtW7Zo4MCBCgwMVLt27TRr1iwZhuGs0j1Szv5S/WrhWu0sLnccY7sCAIAn8HXWL96xY4fsdrtefvllXXrppdq6davuuusuHT9+XM8++6wkqb6+XiNHjlTr1q21du1aHT16VHfccYcMw9CC
BQskSVarVcOGDdPgwYO1fv167dq1SxMnTlRwcLCmTp3qrPI9yorcA3r4H1tkq7frsphW2nmowuySAABoMl5GM05xPPPMM1q0aJH27GlYpG3lypUaNWqUCgoKFBcXJ0latmyZJk6cqJKSEoWGhmrRokWaMWOGDh06JIvFIkl6+umntWDBAh04cOCcTpNYrVaFhYWprKxMoaGhznuBLsZuN/Tsqp3683+/kyQN6xajzF8nq+/sTyVJ22eNUJC/0/IsAAAX5Vy/v5u1Z6asrEyRkT9uYPjFF18oOTnZEWQkacSIEaqpqVFOTo5jzMCBAx1B5sSYwsJC7du377R/p6amRlartdHNGSptdeo4/UN1nP6hKm11Z39CMzpeU6dJb+Q4gswfBl2il29LVbCF8AIA8CzNFma+++47LViwQJMmTXIcKy4uVkxMTKNxERER8vf3V3Fx8RnHnLh/YszJsrKyFBYW5rjFx8c35UtxeQePVek3L32hVdsPyd/HW/PG9NLD13SRN3suAQA80HmHmZkzZ8rLy+tnbxs2bGj0nMLCQl1zzTW6+eab9bvf/a7RY6c7TWQYRqPjJ485cWbsTKeYZsyYobKyMsetoKDgfF+m22po9P2fvimyKrqVv978/S90Y5/2ZpcFAIDTnPc5h3vuuUfjxo372TEdO3Z0/FxYWKjBgwerX79+Wrx4caNxsbGx+uqrrxodKy0tVW1trWP2JTY29pQZmJKSEkk6ZcbmBIvF0ui0VEvxXu5BTXtns2x1dnWJDdFf7khT+4ggs8sCAMCpzjvMREdHKzo6+pzGHjx4UIMHD1Zqaqr++te/ytu78URQv379NHv2bBUVFalt27aSpFWrVslisSg1NdUx5pFHHpHNZpO/v79jTFxcXKPQ1JLZ7Yaey96pP332Y6Pv82N70x8DAGgRnNYzU1hYqEGDBik+Pl7PPvusDh8+rOLi4kazLMOHD1e3bt00YcIE5ebm6tNPP9WDDz6ou+66y9G1PH78eFksFk2cOFFbt27VihUrlJmZqYyMDBZ8U0Oj7x/+nuMIMjT6AgBaGqd9461atUq7d+/W7t271b59456NEz0vPj4++vDDDzV58mQNGDBAgYGBGj9+vGMdGkkKCwtTdna2pkyZorS0NEVERCgjI0MZGRnOKt1tFB6r0u9e36DtRVb5+3jr6Zt60B8DAGhxnBZmJk6cqIkTJ551XEJCgv71r3/97JgePXpozZo1TVSZZ9iYX6rf/y1HRypqFN3KXy9PSFVqh8izPxEAAA/DuQg39H7eQT30Dxp9AQCQCDNu5eRG36FdY/TCOBp9AQAtG9+CbqLSVqcHlufp422HJDU0+j40/DIWwgMAtHiEGTdAoy8AAGdGmHFxufmluotGXwAAzogw48Jo9AUA4OwIMy6IRl8AAM4d344u5uRG30kDL9G0ETT6AgBwJoQZF3Jyo2/WjT10UyqNvgAA/BzCjIv4aaNvVLC/Ft/unEbfIH9f7Xt6ZJP/XgAAzEKYcQE0+gIAcOEIMyay2w3Ny96lhZ/tltTQ6Pv8uN5qRaMvAADnjG9Nk1Ta6pSxfJP+va1YUkOj70MjLpMPjb4AAJwXwowJCo9V6a6/bdC2Qhp9AQC4WISZZpabX6rfL8nR4XLnNvoCANBSEGaaEY2+AAA0PcJMM7DbDc3/ZJcW/IdGXwAAmhrfpk5WaavT1Lc2aeXWhkbfuwd20rQRXWj0BQCgiRBmnKiorGFF3xONvpk39tBvaPQFAKBJEWacJK/gmO762wZHo+/LE1KV1pFGXwAAmhphxgk+2FSoh97epJofGn1fuT1N8ZE0+gIA4AyEmSZktxt6/pNdetHR6NtGz49LodEXAAAn4lu2idDoCwCAOQgzTeT2V9drexGNvgAANDfCTBPZXmSl0RcAABMQZi7CR1uKHD93jmmlV+/oS6MvAADNjDBzEfIKjjl+/vvvrlDrkADzigEAoIXyNrsAdzZtxGWOn4O5YgkAAFMQ
Zi6Crw9vHwAAZuPbGAAAuDXCDAAAcGuEGQAA4NYIMwAAwK0RZgAAgFsjzAAAALdGmAEAAG6NMAMAANwaYQYAALg1wgwAAHBrhBkAAODWCDMAAMCtEWYAAIBbI8wAAAC31ixhpqamRr1795aXl5fy8vIaPZafn6/Ro0crODhY0dHRuu+++2Sz2RqN2bJliwYOHKjAwEC1a9dOs2bNkmEYzVE6AABwcb7N8UemTZumuLg4bdq0qdHx+vp6jRw5Uq1bt9batWt19OhR3XHHHTIMQwsWLJAkWa1WDRs2TIMHD9b69eu1a9cuTZw4UcHBwZo6dWpzlA8AAFyY08PMypUrtWrVKr3zzjtauXJlo8dWrVql7du3q6CgQHFxcZKk5557ThMnTtTs2bMVGhqqv//976qurtZrr70mi8Wi5ORk7dq1S/PmzVNGRoa8vLyc/RIAAIALc+pppkOHDumuu+7SkiVLFBQUdMrjX3zxhZKTkx1BRpJGjBihmpoa5eTkOMYMHDhQFoul0ZjCwkLt27fvtH+3pqZGVqu10Q0AAHgmp4UZwzA0ceJETZo0SWlpaacdU1xcrJiYmEbHIiIi5O/vr+Li4jOOOXH/xJiTZWVlKSwszHGLj4+/2JcDAABc1HmHmZkzZ8rLy+tnbxs2bNCCBQtktVo1Y8aMn/19pztNZBhGo+MnjznR/HumU0wzZsxQWVmZ41ZQUHC+LxMAALiJ8+6ZueeeezRu3LifHdOxY0c99dRT+vLLLxudHpKktLQ03XrrrXr99dcVGxurr776qtHjpaWlqq2tdcy+xMbGnjIDU1JSIkmnzNicYLFYTvm7AADAM513mImOjlZ0dPRZx7344ot66qmnHPcLCws1YsQILV++XFdccYUkqV+/fpo9e7aKiorUtm1bSQ1NwRaLRampqY4xjzzyiGw2m/z9/R1j4uLi1LFjx/MtHwAAeBin9cwkJCQoOTnZcevcubMk6ZJLLlH79u0lScOHD1e3bt00YcIE5ebm6tNPP9WDDz6ou+66S6GhoZKk8ePHy2KxaOLEidq6datWrFihzMxMrmQCAACSTF4B2MfHRx9++KECAgI0YMAAjRkzRjfccIOeffZZx5iwsDBlZ2frwIEDSktL0+TJk5WRkaGMjAwTKwcAAK6iWRbNkxr6aE63am9CQoL+9a9//exze/TooTVr1jirNAAA4MbYmwkAALg1wgwAAHBrhBkAAODWCDMAAMCtEWYAAIBbI8wAAAC31myXZnuiIH9f7Xt6pNllAADQojEzAwAA3BphBgAAuDXCDAAAcGuEGQAA4NYIMwAAwK0RZgAAgFsjzAAAALdGmAEAAG6NMAMAANwaYQYAALg1wgwAAHBrhBkAAODWCDMAAMCtEWYAAIBbI8wAAAC35mt2Ac3BMAxJktVqNbkSAABwrk58b5/4Hj+TFhFmysvLJUnx8fEmVwIAAM5XeXm5wsLCzvi4l3G2uOMB7Ha7CgsLFRISIi8vryb93VarVfHx8SooKFBoaGiT/m78iPe5efA+Nw/e5+bB+9w8nPk+G4ah8vJyxcXFydv7zJ0xLWJmxtvbW+3bt3fq3wgNDeU/lmbA+9w8eJ+bB+9z8+B9bh7Oep9/bkbmBBqAAQCAWyPMAAAAt0aYuUgWi0VPPPGELBaL2aV4NN7n5sH73Dx4n5sH73PzcIX3uUU0AAMAAM/FzAwAAHBrhBkAAODWCDMAAMCtEWYAAIBbI8xchD//+c9KTExUQECAUlNT9fnnn5tdkkfJyspS3759FRISojZt2uiGG27Qzp07zS7L42VlZcnLy0vp6elml+KRDh48qNtuu01RUVEKCgpS7969lZOTY3ZZHqWurk6PPfaYEhMTFRgYqE6dOmnWrFmy2+1ml+bW1qxZo9GjRysuLk5eXl567733Gj1uGIZmzpypuLg4BQYGatCgQdq2bVuz1EaYuUDLly9Xenq6Hn30UeXm5uqqq67Stddeq/z8fLNL8xirV6/W
lClT9OWXXyo7O1t1dXUaPny4jh8/bnZpHmv9+vVavHixevbsaXYpHqm0tFQDBgyQn5+fVq5cqe3bt+u5555TeHi42aV5lDlz5uill17SwoUL9c0332ju3Ll65plntGDBArNLc2vHjx9Xr169tHDhwtM+PnfuXM2bN08LFy7U+vXrFRsbq2HDhjn2R3QqAxfk8ssvNyZNmtToWJcuXYzp06ebVJHnKykpMSQZq1evNrsUj1ReXm4kJSUZ2dnZxsCBA43777/f7JI8zsMPP2xceeWVZpfh8UaOHGn89re/bXTsxhtvNG677TaTKvI8kowVK1Y47tvtdiM2NtZ4+umnHceqq6uNsLAw46WXXnJ6PczMXACbzaacnBwNHz680fHhw4dr3bp1JlXl+crKyiRJkZGRJlfimaZMmaKRI0dq6NChZpfisT744AOlpaXp5ptvVps2bZSSkqJXXnnF7LI8zpVXXqlPP/1Uu3btkiRt2rRJa9eu1XXXXWdyZZ5r7969Ki4ubvS9aLFYNHDgwGb5XmwRG002tSNHjqi+vl4xMTGNjsfExKi4uNikqjybYRjKyMjQlVdeqeTkZLPL8TjLli3Txo0btX79erNL8Wh79uzRokWLlJGRoUceeURff/217rvvPlksFt1+++1ml+cxHn74YZWVlalLly7y8fFRfX29Zs+erVtuucXs0jzWie++030v7t+/3+l/nzBzEby8vBrdNwzjlGNoGvfcc482b96stWvXml2KxykoKND999+vVatWKSAgwOxyPJrdbldaWpoyMzMlSSkpKdq2bZsWLVpEmGlCy5cv1xtvvKGlS5eqe/fuysvLU3p6uuLi4nTHHXeYXZ5HM+t7kTBzAaKjo+Xj43PKLExJSckpqRQX795779UHH3ygNWvWqH379maX43FycnJUUlKi1NRUx7H6+nqtWbNGCxcuVE1NjXx8fEys0HO0bdtW3bp1a3Ssa9eueuedd0yqyDM99NBDmj59usaNGydJ6tGjh/bv36+srCzCjJPExsZKapihadu2reN4c30v0jNzAfz9/ZWamqrs7OxGx7Ozs9W/f3+TqvI8hmHonnvu0bvvvqv//Oc/SkxMNLskjzRkyBBt2bJFeXl5jltaWppuvfVW5eXlEWSa0IABA05ZXmDXrl3q0KGDSRV5psrKSnl7N/568/Hx4dJsJ0pMTFRsbGyj70WbzabVq1c3y/ciMzMXKCMjQxMmTFBaWpr69eunxYsXKz8/X5MmTTK7NI8xZcoULV26VO+//75CQkIcM2FhYWEKDAw0uTrPERISckofUnBwsKKiouhPamIPPPCA+vfvr8zMTI0ZM0Zff/21Fi9erMWLF5tdmkcZPXq0Zs+erYSEBHXv3l25ubmaN2+efvvb35pdmlurqKjQ7t27Hff37t2rvLw8RUZGKiEhQenp6crMzFRSUpKSkpKUmZmpoKAgjR8/3vnFOf16KQ/2pz/9yejQoYPh7+9v9OnTh0uGm5ik097++te/ml2ax+PSbOf55z//aSQnJxsWi8Xo0qWLsXjxYrNL8jhWq9W4//77jYSEBCMgIMDo1KmT8eijjxo1NTVml+bWPvvss9N+Jt9xxx2GYTRcnv3EE08YsbGxhsViMa6++mpjy5YtzVKbl2EYhvMjEwAAgHPQMwMAANwaYQYAALg1wgwAAHBrhBkAAODWCDMAAMCtEWYAAIBbI8wAAAC3RpgBAABujTADAADcGmEGAAC4NcIMAABwa4QZAADg1v5/UITx/3HITVAAAAAASUVORK5CYII=",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "x_data = []\n",
+ "y_data = []\n",
+ "ydataerr = []\n",
+ "times = [datetime.now()]\n",
+ "\n",
+ "def progress(it, metrics):\n",
+ " times.append(datetime.now())\n",
+ " x_data.append(it)\n",
+ " y_data.append(metrics['eval/episode_reward'])\n",
+ " ydataerr.append(metrics['eval/episode_reward_std'])\n",
+ "\n",
+ "env = envs.get_environment(\"anymal\", **env_kwargs)\n",
+ "eval_env = envs.get_environment(\"anymal\", **env_kwargs)\n",
+ "\n",
+ "make_inference_fn, params, _ = train_fn(environment=env,\n",
+ " progress_fn=progress,\n",
+ " eval_env=eval_env)\n",
+ "\n",
+ "plt.errorbar(x_data, y_data, yerr=ydataerr)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "demo_env = envs.training.EpisodeWrapper(env, \n",
+ " episode_length=1000, \n",
+ " action_repeat=1)\n",
+ "\n",
+ "render_rollout(\n",
+ " jax.jit(demo_env.reset),\n",
+ " jax.jit(demo_env.step),\n",
+ " jax.jit(make_inference_fn(params)),\n",
+ " demo_env,\n",
+ " n_steps=200,\n",
+ " camera=\"track\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**A note on sample efficiency**\n",
+ "\n",
+ "Let's compare with PPO, again using 1e7 samples:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Text(0, 0.5, 'reward per episode')"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "train_fn = functools.partial(\n",
+ " ppo.train, num_timesteps=10_000_000, num_evals=10, reward_scaling=0.1,\n",
+ " episode_length=1000, normalize_observations=True, action_repeat=1,\n",
+ " unroll_length=10, num_minibatches=32, num_updates_per_batch=8,\n",
+ " discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=1024,\n",
+ " batch_size=1024, seed=0)\n",
+ "\n",
+ "x_data = []\n",
+ "y_data = []\n",
+ "ydataerr = []\n",
+ "\n",
+ "env = envs.get_environment(\"anymal\", **env_kwargs)\n",
+ "\n",
+ "def progress(num_steps, metrics):\n",
+ " x_data.append(num_steps)\n",
+ " y_data.append(metrics['eval/episode_reward'])\n",
+ " ydataerr.append(metrics['eval/episode_reward_std'])\n",
+ "\n",
+ "make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)\n",
+ "\n",
+ "plt.errorbar(x_data, y_data, yerr=ydataerr)\n",
+ "plt.xlabel('# environment steps')\n",
+ "plt.ylabel('reward per episode')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We see that PPO struggles to learn locomotion in this setup, even with over 10x the number of simulator steps. \n",
+ "\n",
+ "Rather than indicating a shortcoming of PPO, this comparison demonstrates a policy-learning setup that plays to the strengths of FoPG methods. It involves learning small, accurate perturbations from a good baseline, akin to optimizing toward the local minimum of a deep valley. FoPGs are precise enough to guide these perturbations using nuanced reward signals, such as our foot placement spline.\n",
+ "\n",
+ "In contrast, RL algorithms such as PPO work well in [policy-learning setups](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb) with less structured rewards. Unlike FoPG methods, they [benefit greatly](https://www.science.org/doi/abs/10.1126/scirobotics.adg1462) from sparse, non-differentiable rewards, such as a large penalty for falling."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "mujoco",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}