From 6602c57646ce2dda5347d01642a0abb1b13912bb Mon Sep 17 00:00:00 2001 From: RobertTLange Date: Thu, 18 Nov 2021 12:30:50 +0100 Subject: [PATCH] Experimental `EnvRollout` --- README.md | 2 +- examples/getting_started.ipynb | 83 ++++++++++++++++++++++++------- _version.py => gymnax/_version.py | 0 gymnax/experimental/rollout.py | 78 +++++++++++++++++++++++++++++ 4 files changed, 144 insertions(+), 19 deletions(-) rename _version.py => gymnax/_version.py (100%) create mode 100644 gymnax/experimental/rollout.py diff --git a/README.md b/README.md index 9915990..f5902b7 100755 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ rng, key_reset, key_policy, key_step = jax.random.split(rng, 4) env, env_params = gymnax.make("Pendulum-v1") obs, state = env.reset(key_reset, env_params) -action = env.action_space.sample(key_policy) +action = env.action_space(env_params).sample(key_policy) n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params) ``` diff --git a/examples/getting_started.ipynb b/examples/getting_started.ipynb index da139a6..e8eefe1 100644 --- a/examples/getting_started.ipynb +++ b/examples/getting_started.ipynb @@ -5,7 +5,7 @@ "metadata": {}, "source": [ "# `gymnax`: Classic Gym Environments in JAX\n", - "### Author: [@RobertTLange](https://twitter.com/RobertTLange) [Last Update: October 2021][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/gymnax/blob/main/examples/getting_started.ipynb)\n", + "### Author: [@RobertTLange](https://twitter.com/RobertTLange) [Last Update: November 2021][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/gymnax/blob/main/examples/getting_started.ipynb)\n", "" ] }, @@ -13,16 +13,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Basic API: `gymnax.make()`, `env.reset()`, `env.step()`" + "## Basic API: `gymnax.make()`, `env.reset()`, `env.step()`" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import jax\n", + "import jax.numpy as jnp\n", "import gymnax\n", "\n", "rng = jax.random.PRNGKey(0)\n", @@ -31,13 +32,13 @@ "env, env_params = gymnax.make(\"Pendulum-v1\")\n", "\n", "obs, state = env.reset(key_reset, env_params)\n", - "action = env.action_space.sample(key_policy)\n", + "action = env.action_space(env_params).sample(key_policy)\n", "n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -48,12 +49,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Jitted Episode Rollouts via `lax.scan`" + "## Jitted Episode Rollouts via `lax.scan`" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -68,7 +69,7 @@ " num_output_units: int\n", "\n", " @nn.compact\n", - " def __call__(self, x):\n", + " def __call__(self, x, rng):\n", " for l in range(self.num_hidden_layers):\n", " x = nn.Dense(features=self.num_hidden_units)(x)\n", " x = nn.relu(x)\n", @@ -77,12 +78,12 @@ " \n", "\n", "network = MLP(48, 1, 1)\n", - "policy_params = network.init(rng, jnp.zeros(3))[\"params\"]" + "policy_params = network.init(rng, jnp.zeros(3), None)[\"params\"]" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -95,8 +96,8 @@ " def 
policy_step(state_input, tmp):\n", " \"\"\"lax.scan compatible step transition in jax env.\"\"\"\n", " obs, state, policy_params, rng = state_input\n", - " rng, rng_step = jax.random.split(rng)\n", - " action = network.apply({\"params\": policy_params}, obs)\n", + " rng, rng_step, rng_net = jax.random.split(rng, 3)\n", + " action = network.apply({\"params\": policy_params}, obs, rng_net)\n", " next_o, next_s, reward, done, _ = env.step(\n", " rng_step, state, action, env_params\n", " )\n", @@ -118,15 +119,54 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray(-1600.6174, dtype=float32)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Jit-Compiled Episode Rollout\n", "jit_rollout = jax.jit(rollout, static_argnums=3)\n", "jit_rollout(rng, policy_params, env_params, 200)" ] }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray([-1076.0648, -1299.151 , -1502.7104, -1210.023 , -1287.3069,\n", + " -1267.0559, -1529.9363, -1517.0801, -1809.3643, -1565.3792], dtype=float32)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from gymnax.experimental.rollout import EnvRollout\n", + "\n", + "roller = EnvRollout(model_forward=network.apply,\n", + " env_name=\"Pendulum-v1\",\n", + " num_env_steps=200,\n", + " num_episodes=10)\n", + "\n", + "roller.collect(rng, policy_params)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -168,7 +208,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Distributed Anakin Agent on BSuite Catch Environment\n", + "# Distributed Anakin Agent\n", "\n", "\n", "Adapted from Hessel et al. 
(2021) and DeepMind's [Example Colab](https://colab.research.google.com/drive/1974D-qP17fd5mLxy6QZv-ic4yxlPJp-G?usp=sharing#scrollTo=lhnJkrYLOvcs)" @@ -418,6 +458,13 @@ "run_experiment(env, 128, 16, 1e-4, 100, 42)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Vectorized Population Evaluation for CMA-ES" + ] + }, { "cell_type": "code", "execution_count": null, @@ -428,9 +475,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python3 (mle-toolbox)", + "display_name": "snippets", "language": "python", - "name": "mle-toolbox" + "name": "snippets" }, "language_info": { "codemirror_mode": { @@ -442,7 +489,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.6" + "version": "3.8.11" } }, "nbformat": 4, diff --git a/_version.py b/gymnax/_version.py similarity index 100% rename from _version.py rename to gymnax/_version.py diff --git a/gymnax/experimental/rollout.py b/gymnax/experimental/rollout.py new file mode 100644 index 0000000..e0fef18 --- /dev/null +++ b/gymnax/experimental/rollout.py @@ -0,0 +1,78 @@ +import jax +import jax.numpy as jnp +import gymnax + + +class EnvRollout(object): + def __init__( + self, + model_forward, + env_name: str = "Pendulum-v1", + num_env_steps: int = 200, + num_episodes: int = 20, + ): + """Wrapper to define batch evaluation for generation parameters.""" + self.env_name = env_name + self.num_env_steps = num_env_steps + self.num_episodes = num_episodes + + # Define the RL environment & network forward function + self.env, self.env_params = gymnax.make(self.env_name) + self.model_forward = model_forward + + # Set up the generation evaluation vmap-ed function - rl/supervised/etc. + self.gen_evaluate = self.batch_evaluate() + + def collect(self, rng_eval, policy_params): + """Reshape parameter vector and evaluate the generation.""" + # Reshape the parameters into the correct network format + rollout_keys = jax.random.split(rng_eval, self.num_episodes) + + # Evaluate generation population on pendulum task - min cost! + pop_trajectories = self.gen_evaluate(rollout_keys, policy_params) + return pop_trajectories + + def batch_evaluate(self): + """Evaluate a generation of networks on RL/Supervised/etc. 
task."""
+        # vmap over different MC fitness evaluations for single network
+        batch_rollout = jax.jit(
+            jax.vmap(self.rollout, in_axes=(0, None), out_axes=0)
+        )
+        return batch_rollout
+
+    def rollout(self, rng_input, policy_params):
+        """Rollout a pendulum episode with lax.scan."""
+        # Reset the environment
+        rng_reset, rng_episode = jax.random.split(rng_input)
+        obs, state = self.env.reset(rng_reset, self.env_params)
+
+        def policy_step(state_input, tmp):
+            """lax.scan compatible step transition in jax env."""
+            obs, state, policy_params, rng = state_input
+            rng, rng_step, rng_net = jax.random.split(rng, 3)
+            action = self.model_forward({"params": policy_params}, obs, rng=rng_net)
+            next_o, next_s, reward, done, _ = self.env.step(
+                rng_step, state, action, self.env_params
+            )
+            carry = [next_o.squeeze(), next_s, policy_params, rng]
+            y = [next_o.squeeze(), reward, done]
+            return carry, y
+
+        # Scan over episode step loop
+        _, scan_out = jax.lax.scan(
+            policy_step,
+            [obs, state, policy_params, rng_episode],
+            [jnp.zeros((self.num_env_steps, self.input_shape[0] + 2))],
+        )
+        # Return the sum of rewards accumulated by agent in episode rollout
+        obs, rewards, dones = scan_out[0], scan_out[1], scan_out[2]
+        rewards = rewards.reshape(self.num_env_steps, 1)
+        ep_mask = (jnp.cumsum(dones) < 1).reshape(self.num_env_steps, 1)
+        return obs, rewards, dones, jnp.sum(rewards * ep_mask)
+
+    @property
+    def input_shape(self):
+        """Get the shape of the observation."""
+        rng = jax.random.PRNGKey(0)
+        obs, state = self.env.reset(rng, self.env_params)
+        return obs.shape
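
For reference, a minimal sketch (not part of this patch) of how the new `EnvRollout.collect` could be vmapped over a population of policy parameters, in the spirit of the "Vectorized Population Evaluation for CMA-ES" section stubbed out in the notebook. The `MLP` mirrors the notebook's toy policy (the final output layer is assumed); names such as `pop_keys`, `pop_params`, and `pop_collect` are illustrative only.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

from gymnax.experimental.rollout import EnvRollout


class MLP(nn.Module):
    """Toy policy as in the getting_started notebook; rng is unused here."""
    num_hidden_units: int
    num_hidden_layers: int
    num_output_units: int

    @nn.compact
    def __call__(self, x, rng):
        for _ in range(self.num_hidden_layers):
            x = nn.Dense(features=self.num_hidden_units)(x)
            x = nn.relu(x)
        return nn.Dense(features=self.num_output_units)(x)


rng = jax.random.PRNGKey(0)
network = MLP(48, 1, 1)

# Batch evaluator: 10 MC episodes of 200 steps per candidate (as in the notebook).
roller = EnvRollout(model_forward=network.apply,
                    env_name="Pendulum-v1",
                    num_env_steps=200,
                    num_episodes=10)

# Build a small "population" of parameter pytrees by stacking independent
# initializations along a leading axis (stand-in for CMA-ES candidates).
pop_keys = jax.random.split(rng, 5)
pop_params = jax.vmap(
    lambda k: network.init(k, jnp.zeros(3), None)["params"]
)(pop_keys)

# vmap the per-candidate collection over the population axis of the params.
pop_collect = jax.vmap(roller.collect, in_axes=(None, 0))
pop_out = pop_collect(rng, pop_params)
# Per the rollout above, the last entry of pop_out holds the masked episode
# returns with shape (population, num_episodes), i.e. (5, 10) here.
```

In an actual CMA-ES loop, the stacked initializations would be replaced by candidate solutions reshaped into the network's parameter pytree before calling `pop_collect`.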