add tianshou-like JAX+PPO+Mujoco #355
base: master
Conversation
Hi @quangr, thanks for this awesome contribution! Being able to use JAX+PPO+MuJoCo+EnvPool will be a game-changer for a lot of folks! This PR will also make #217 not necessary.
Some comments and thoughts:
- Would you mind sharing your wandb username so I can add you to the openrlbenchmark entity? It would be great if you could contribute tracked experiments there, and we can use our CLI utility (https://github.com/openrlbenchmark/openrlbenchmark) to plot charts.
- Could you share your huggingface username and help add saved models? On second thought, there might not be a way to render MuJoCo images with EnvPool, so don't worry about this yet.
We recently added the huggingface integration as follows (cleanrl/cleanrl/ppo_atari_envpool_xla_jax_scan.py, lines 477 to 511 in d0d6bae):
```python
if args.save_model:
    model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
    with open(model_path, "wb") as f:
        f.write(
            flax.serialization.to_bytes(
                [
                    vars(args),
                    [
                        agent_state.params.network_params,
                        agent_state.params.actor_params,
                        agent_state.params.critic_params,
                    ],
                ]
            )
        )
    print(f"model saved to {model_path}")
    from cleanrl_utils.evals.ppo_envpool_jax_eval import evaluate

    episodic_returns = evaluate(
        model_path,
        make_env,
        args.env_id,
        eval_episodes=10,
        run_name=f"{run_name}-eval",
        Model=(Network, Actor, Critic),
    )
    for idx, episodic_return in enumerate(episodic_returns):
        writer.add_scalar("eval/episodic_return", episodic_return, idx)

    if args.upload_model:
        from cleanrl_utils.huggingface import push_to_hub

        repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
        repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
        push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval")
```

You can load the trained model by running `python -m cleanrl_utils.enjoy --exp-name ppo_atari_envpool_xla_jax_scan --env-id Breakout-v5`.
```python
def compute_gae_once(carry, inp, gamma, gae_lambda):
    advantages = carry
    nextdone, nexttruncated, nextvalues, curvalues, reward = inp
    nextnonterminal = (1.0 - nextdone) * (1.0 - nexttruncated)
```
gym's truncation / termination handling has been a mess, so it's confusing to handle value estimation on truncation correctly.

If I understood correctly, to handle value estimation correctly you should use `nextnonterminal = (1.0 - nextdone)`, which is equivalent to `nextnonterminal = (1.0 - next_terminated)` under `env_type="gymnasium"`.

If you don't (which is fine, so that this implementation stays consistent with the other PPO variants), then the current implementation is fine :)

If you choose to handle it correctly, make sure to add a note to the implementation details section of the docs, since it's a deviation from the original implementation.
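For illustration, here is a minimal sketch of that GAE step under the suggested handling; this is not the PR's exact code, and the `nextterminated` input plus the assumption that `nextvalues` estimates the true final observation (not the auto-reset one) are assumptions of the sketch:

```python
import jax.numpy as jnp

def compute_gae_once(carry, inp, gamma, gae_lambda):
    # One step of the reverse scan over the rollout (used with jax.lax.scan(reverse=True)).
    advantages = carry
    nextterminated, nextvalues, curvalues, reward = inp
    # Bootstrap through truncations; only a real termination masks the next value.
    nextnonterminal = 1.0 - nextterminated
    delta = reward + gamma * nextvalues * nextnonterminal - curvalues
    advantages = delta + gamma * gae_lambda * nextnonterminal * advantages
    return advantages, advantages
```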
Just curious: are we currently using gymnasium? My last PR still uses the normal gym interface. Is this version bump a dependency of any of these libs?
```python
if args.rew_norm:
    returns = (returns / jnp.sqrt(agent_state.ret_rms.var + 10e-8)).astype(jnp.float32)
    agent_state = agent_state.replace(ret_rms=agent_state.ret_rms.update(returns.flatten()))
```
Interesting. Going through the logic here, I realize this reward normalization is quite different from the original implementation, which may or may not be fine, but it's worth pointing out the difference.

The original implementation keeps a forward discounted return and normalizes the reward on a per-step basis. See source1 and source2.

Here, the returns are normalized only after the rollout phase, so the rewards themselves are not normalized.

These two approaches don't feel equivalent. If you want to get to the bottom of this, it might be worth conducting an empirical study comparing this implementation and #348.

Both approaches seem fine to me, but we should document it if we choose the current approach.
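To make the difference concrete, here is a rough sketch of the two schemes; the classes and functions below are hypothetical helpers, not code from either implementation:

```python
import numpy as np

class RunningVar:
    """Tiny running-variance tracker shared by both sketches below."""

    def __init__(self):
        self.mean, self.var, self.count = 0.0, 1.0, 1e-4

    def update(self, x):
        batch_mean, batch_var, batch_count = x.mean(), x.var(), x.size
        delta = batch_mean - self.mean
        tot = self.count + batch_count
        self.mean += delta * batch_count / tot
        self.var = (self.var * self.count + batch_var * batch_count
                    + delta**2 * self.count * batch_count / tot) / tot
        self.count = tot


# Scheme A (original implementation, per-step): keep a discounted-return
# accumulator per env and divide every *reward* by the running std of it.
class PerStepRewardNorm:
    def __init__(self, num_envs, gamma=0.99, eps=1e-8):
        self.ret = np.zeros(num_envs)
        self.rms, self.gamma, self.eps = RunningVar(), gamma, eps

    def __call__(self, reward, done):
        self.ret = self.ret * self.gamma + reward
        self.rms.update(self.ret)
        self.ret = np.where(done, 0.0, self.ret)
        return reward / np.sqrt(self.rms.var + self.eps)


# Scheme B (this PR / tianshou-style, per-rollout): rewards stay untouched
# during the rollout; the computed *returns* are scaled once per update, and
# the running statistics are updated from the unnormalized returns.
def normalize_returns(returns, rms, eps=1e-8):
    normalized = returns / np.sqrt(rms.var + eps)
    rms.update(returns.flatten())
    return normalized
```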
I also noticed there is no observation normalization. Is this intended?
Thank you for a nice PR! Still, there are some unjustified changes which might cause performance differences vs the other versions. In general, aim to reduce the number of code changes (see #328 (comment) for how to compare your PR with a reference implementation). For your case, it might be a good idea to compare against ppo_continuous_action.py and ppo_atari_envpool_xla_jax_scan.py. A good rule of thumb is to stay as close to these references as possible.
Thanks @51616 and @vwxyzjn for the code review ❤️! My wandb username is quangr. I will check these comments and improve my code soon.
I have submitted new commits addressing most of the comments; here are my answers to the remaining ones. If something is missing, please let me know.
I'm masking the done value because tianshou does so: https://github.com/thu-ml/tianshou/blob/774d3d8e833a1b1c1ed320e0971ab125161f4264/tianshou/policy/base.py#L288. You're right; I'm moving it into the compute_gae_once function.
The XLA API provided by envpool is not a pure function. The handle passed to the send function is just a fat pointer to the envpool class. If we keep all state inside a handle tuple and then reset the environment, the pointer remains unchanged, and other parts (like the new statistics state) also require a reset. So I think there has to be a change to the envpool API. To make the API less confusing, maybe we could remove the handle from the return values of envs.xla(). I can't think of a way to make things consistent with envpool for now.
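For context, a rough sketch of the EnvPool XLA interface being described, based on EnvPool's documented `xla()` API; the env id, `num_envs`, and the step body are illustrative, not the PR's code:

```python
import envpool
import jax

envs = envpool.make("HalfCheetah-v3", env_type="gym", num_envs=8)
handle, recv, send, step = envs.xla()

@jax.jit
def rollout_step(handle, action):
    # `handle` is effectively an opaque reference to the underlying C++ env
    # pool rather than pure functional state: resetting the environments does
    # not produce a new handle value, which is why extra wrapper state (e.g.
    # normalization statistics) has to be carried and reset separately.
    handle, (obs, reward, done, info) = step(handle, action)
    return handle, (obs, reward, done)
```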
cleanrl/cleanrl/ppo_continuous_action_envpool_xla_jax_scan.py Lines 313 to 317 in c79d66f
There is observation normalization; it is implemented as a wrapper, and I think observation normalization is mandatory for achieving high scores in MuJoCo envs. In this observation normalization wrapper I actually turn the gym API into the gymnasium API: cleanrl/cleanrl/ppo_continuous_action_envpool_xla_jax_scan.py, lines 177 to 187 in c79d66f.
This is because when I wrote this code I used the latest envpool version (0.8.1), which uses the gymnasium API.
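The PR's wrapper isn't reproduced here, but a generic running-mean/std observation normalization of the kind described might look roughly like this (all names and the flax-struct layout are assumptions):

```python
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class RunningMeanStd:
    mean: jnp.ndarray
    var: jnp.ndarray
    count: float

    def update(self, x):
        # Parallel (Chan et al.) update of mean/var from a batch of observations.
        batch_mean, batch_var, batch_count = x.mean(0), x.var(0), x.shape[0]
        delta = batch_mean - self.mean
        tot = self.count + batch_count
        new_mean = self.mean + delta * batch_count / tot
        new_var = (self.var * self.count + batch_var * batch_count
                   + delta**2 * self.count * batch_count / tot) / tot
        return RunningMeanStd(mean=new_mean, var=new_var, count=tot)


def normalize_obs(rms: RunningMeanStd, obs, eps=1e-8):
    return (obs - rms.mean) / jnp.sqrt(rms.var + eps)
```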
This reward normalization looks bizarre to me too, but this is how tianshou implements it, and it really works (see the HalfCheetah-v3 and Hopper-v3 results).
I have run experiments for Ant-v4, HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Humanoid-v4, Reacher-v4, InvertedPendulum-v4, and InvertedDoublePendulum-v4. Here is the report: https://wandb.ai/openrlbenchmark/cleanrl/reports/MuJoCo-jax-EnvPool--VmlldzozNDczNDkz.

Comparing the results to tianshou, I notice that tianshou uses 10 envs to evaluate performance from the reset state every epoch as their benchmark, so I wonder if that is a problem for the comparison.

Compared with ppo_continuous_action_8M, mine is better on Ant-v4 and HalfCheetah-v4. As for InvertedDoublePendulum-v4 and InvertedPendulum-v4, every agent in my version reaches a score of 1000, which does not happen in ppo_continuous_action_8M, but it starts to decline afterward; the same decline curve can be observed in tianshou's training data: https://drive.google.com/drive/folders/1tQvgmsBbuLPNU3qo5thTBi03QzGXygXf
Thanks for running the results! They look great. Meanwhile, you might find the following tool handy, which generates the wandb report. A couple of notes:
The hyperparameters used are in https://github.com/vwxyzjn/envpool-cleanrl/blob/880552a168c08af334b5e5d8868bfbe5ea881445/ppo_continuous_action_envpool.py#L40-L75, with no observation or reward normalization.
Thanks for updating me. I'll be ready for the code review once the documentation is finished. I'm also happy to run the experiment you suggested.
I have documented the questions you brought up and I am now ready for the code review. I would be happy to hear your feedback.
Description
Add tianshou-like JAX+PPO+Mujoco code, which is tested in Hopper-v3 and HalfCheetah-v3.
11 seed test
Hopper-v3 (Tianshou 1M:2609.3+-700.8 ; 3M:3127.7+-413.0)
![Hopper-v3](https://user-images.githubusercontent.com/22930473/215747610-a5a611e2-083b-4e0f-a03b-b6a449a2ad4e.png)
my result:
HalfCheetah-v3 (Tianshou 1M:5783.9+-1244.0 ; 3M:7337.4+-1508.2)
![HalfCheetah-v3](https://user-images.githubusercontent.com/22930473/215747604-8a812d7a-7b3b-4ead-bf67-1fa9f17d6ec0.png)
my result:
This implementation uses a customized `EnvWrapper` class to wrap the environment. Unlike a traditional Gym-style wrapper, which exposes `step` and `reset` methods, `EnvWrapper` requires three methods: `recv`, `send`, and `reset`. These methods need to be pure functions so they can be transformed by JAX. The `recv` method modifies what is received from the env after an action step, and the `send` method modifies the action sent to the env.
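As a rough illustration of the interface described above (field and method names other than `recv`, `send`, and `reset` are assumptions, not the PR's exact code):

```python
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class WrapperState:
    # Running statistics carried explicitly, since the envpool handle itself
    # holds no wrapper state.
    obs_mean: jnp.ndarray
    obs_var: jnp.ndarray
    count: float


class EnvWrapper:
    """Pure recv/send/reset hooks that JAX can trace and transform."""

    def reset(self, obs):
        # Build the initial wrapper state from the first batch of observations.
        state = WrapperState(obs_mean=obs.mean(0), obs_var=obs.var(0) + 1e-4, count=float(obs.shape[0]))
        return state, self._normalize(state, obs)

    def recv(self, state, ret):
        # Post-process what the env returned after a step (statistics update omitted).
        obs, reward, done, info = ret
        return state, (self._normalize(state, obs), reward, done, info)

    def send(self, state, action):
        # Transform the action before it is sent to the env.
        return jnp.clip(action, -1.0, 1.0)

    @staticmethod
    def _normalize(state, obs, eps=1e-8):
        return (obs - state.obs_mean) / jnp.sqrt(state.obs_var + eps)
```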
Types of changes

Checklist:
- `pre-commit run --all-files` passes (required).
- Documentation updated and previewed via `mkdocs serve`.

If you are adding new algorithm variants or your change could result in a performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.
- Videos added with the `--capture-video` flag toggled on (required).
- New documentation previewed via `mkdocs serve`.