Updated basic APG algorithm #476

Merged: 7 commits into google:main on Apr 18, 2024
Conversation

Andrew-Luo1 (Contributor)

The goal of this proposed update is to provide a simple APG algorithm that can solve non-trivial tasks, as a first step for researchers and practitioners to explore Brax's differentiable simulation. It has been tested on MJX. Notes:

  • A demonstration of this algorithm is shown here.
  • I have not benchmarked this algorithm against Brax's RL algorithms such as PPO and SAC, since a) the environments would need differentiable rewards, and b) this simple implementation aims to be a basic reference to extend upon.

1: Algorithm Update

This fork contains an APG implementation that is about as simple as the current one, but reflects a common thread between several recent results that have used differentiable simulators to achieve locomotion: 1, 2, 3.

Brax's current APG algorithm is roughly equivalent to the following pseudocode:

for i in range(n_epochs):
    reset state
    policy_grads = []
    for j in range(episode_length // short_horizon):
        state, policy_grad = jax.grad(unroll(state, policy, short_horizon))
        policy_grads.append(policy_grad)
    optimizer.step(mean(policy_grads))

In contrast, the cited results update the policy much more frequently, exploiting the observation that policy gradients obtained by differentiating through the simulator have low variance. Unrolling for an entire episode before updating therefore has limited benefit; the fact that extra samples stop helping past a certain point is also reflected in convergence not improving under massive parallelization [2]. The proposed APG algorithm essentially performs online stochastic gradient descent on the policy: it unrolls for a short window, takes a gradient step, then continues where it left off:

reset state
for i in range(n_epochs):
    state, policy_grad = jax.grad(unroll(state, policy, short_horizon))
    optimizer.step(policy_grad)

Note that n_epochs can be much larger in this case. This modification allows the algorithm to relatively quickly learn quadruped locomotion, albeit with a particular training pipeline and reward design. (Notebook)
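
For concreteness, here is a minimal, self-contained JAX sketch of this update loop against a toy differentiable step function. This is not the Brax/MJX implementation: the names (toy_step, policy_apply, unroll_loss) and all hyperparameters are purely illustrative.

import jax
import jax.numpy as jnp
import optax

HORIZON = 32      # short_horizon in the pseudocode above
N_EPOCHS = 1000   # one gradient update per epoch

def toy_step(state, action):
    # Stand-in for a differentiable simulator step (e.g. an MJX pipeline step).
    next_state = state + 0.1 * action
    reward = -jnp.sum(next_state ** 2)
    return next_state, reward

def policy_apply(params, obs):
    return jnp.tanh(params["w"] @ obs + params["b"])

def unroll_loss(params, state):
    # Negative return of a short rollout, differentiable end-to-end through toy_step.
    def body(carry, _):
        s, ret = carry
        s, r = toy_step(s, policy_apply(params, s))
        return (s, ret + r), None
    (state, ret), _ = jax.lax.scan(body, (state, 0.0), None, length=HORIZON)
    return -ret, state

# Differentiate w.r.t. the policy parameters only; the carried-over state is
# returned as auxiliary output, so no gradient flows across window boundaries.
grad_fn = jax.jit(jax.value_and_grad(unroll_loss, has_aux=True))

params = {"w": 0.01 * jax.random.normal(jax.random.PRNGKey(0), (4, 4)),
          "b": jnp.zeros(4)}
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

state = jnp.ones(4)  # "reset state" happens once, outside the update loop
for i in range(N_EPOCHS):
    (loss, state), grads = grad_fn(params, state)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    if i % 200 == 0:
        print(f"epoch {i}: loss {loss:.4f}")

As in the pseudocode, the environment state persists across optimizer steps, while each gradient only looks back over the short window.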

Additional notes:

  • This fork uses a learning rate schedule to improve training stability.
  • The particular choices of Adam optimizer parameters come from [2] and significantly improve training outcomes. (A sketch of how a schedule and custom Adam parameters fit together is shown after this list.)
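
Below is a hedged sketch of how a learning rate schedule and custom Adam parameters might be wired together with optax; the schedule shape and the b1/b2 values are placeholders for illustration, not the settings used in this fork or in [2].

import optax

# Placeholder values, for illustration only.
lr_schedule = optax.linear_schedule(
    init_value=1e-3,        # initial learning rate
    end_value=1e-4,         # final learning rate
    transition_steps=1000,  # optimizer steps over which to anneal
)
optimizer = optax.adam(learning_rate=lr_schedule, b1=0.9, b2=0.99)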

2: Supporting Updates

Configurable initial policy variance: When hotstarting a policy, it helps to explore close to the states that policy already induces. One way to do this is to initialize the policy network weights to small values, but the softplus applied to the action standard deviation currently prevents this: it maps near-zero inputs to a sizeable standard deviation. This fork therefore adds a scaling parameter (var_scale).
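
To see why the scaling helps (a sketch of the idea only; the exact expression inside Brax's NormalTanhDistribution may differ): softplus maps near-zero inputs to roughly 0.69, so even a policy initialized with tiny weights starts with a fairly wide action distribution. A var_scale-style multiplier lets a hotstarted policy begin with a tight one.

import jax
import jax.numpy as jnp

def action_std(raw_scale, min_std=0.001, var_scale=1.0):
    # Hypothetical form of the scaled standard deviation; Brax's exact
    # expression may order these operations differently.
    return (jax.nn.softplus(raw_scale) + min_std) * var_scale

print(action_std(jnp.zeros(3)))                 # ~0.69 each: wide, even with tiny weights
print(action_std(jnp.zeros(3), var_scale=0.1))  # ~0.07 each: stays near the hotstart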

Layer norm: I have found that layer normalization in the policy network greatly improves the training stability of APG methods; it also appears in other implementations.
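
As an illustration only (not Brax's apg_networks code; the layer widths and activation are arbitrary), layer normalization slots into a policy MLP like this:

import flax.linen as nn

class PolicyMLP(nn.Module):
    action_size: int

    @nn.compact
    def __call__(self, obs):
        x = obs
        for width in (256, 256):
            x = nn.Dense(width)(x)
            x = nn.LayerNorm()(x)  # normalize hidden pre-activations for stability
            x = nn.swish(x)
        return nn.Dense(self.action_size)(x)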


google-cla bot commented Apr 15, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@erikfrey (Collaborator) left a comment


Looks great, just one nit.

@@ -131,7 +131,7 @@ def forward_log_det_jacobian(self, x):
 class NormalTanhDistribution(ParametricDistribution):
   """Normal distribution followed by tanh."""

-  def __init__(self, event_size, min_std=0.001):
+  def __init__(self, event_size, min_std=0.001, var_scale=1):

nit: add var_scale to Args section of docstring

@erikfrey (Collaborator)

Oh - it looks like you need to update brax/training/agents/apg/train_test.py - changes should hopefully be minimal - please review the failing test.

Also, please do sign the CLA. Thank you!

@Andrew-Luo1 (Contributor, Author)

Hi @erikfrey, I have updated the tests and they pass on my local setup. I've also fixed the nit. I've signed the CLA, and the Checks tab is saying that my signing went through. Please let me know if there's anything missing.

@erikfrey (Collaborator) left a comment


Almost there!

@@ -22,6 +22,9 @@
 from brax.training.agents.apg import networks as apg_networks
 from brax.training.agents.apg import train as apg
 import jax
+from jax import config
+config.update("jax_enable_x64", True)

these config changes are process-wide, so it's breaking other tests that expect default float width and precision. we run all the tests in a single process.

the envs in this test are simple enough that hopefully they don't need these config changes. can you try removing these config changes or otherwise tweak the tests so that jax_enable_x64 and the precision change are not needed?

@Andrew-Luo1 (Contributor, Author)

I removed the double precision toggle and the tests still run fine on my local setup. Let's see if this works :)

@erikfrey (Collaborator)

Amazing, thank you!

@erikfrey erikfrey merged commit b45760c into google:main Apr 18, 2024
3 checks passed