Experimental EnvRollout

RobertTLange committed Nov 18, 2021
1 parent 8f77572 commit 6602c57
Showing 4 changed files with 144 additions and 19 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -18,7 +18,7 @@ rng, key_reset, key_policy, key_step = jax.random.split(rng, 4)
env, env_params = gymnax.make("Pendulum-v1")

obs, state = env.reset(key_reset, env_params)
action = env.action_space.sample(key_policy)
action = env.action_space(env_params).sample(key_policy)
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
```
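Because `env.reset` and `env.step` carry all state explicitly, they compose directly with JAX transformations. A minimal sketch reusing the names from the snippet above (illustrative only, not part of the committed README):

```python
import jax

# Jit-compile the pure step function once; later calls reuse the compiled kernel.
jit_step = jax.jit(env.step)
n_obs, n_state, reward, done, _ = jit_step(key_step, state, action, env_params)

# The same functional signature also admits batching over a set of rng keys.
batch_reset = jax.vmap(env.reset, in_axes=(0, None))
batch_obs, batch_states = batch_reset(jax.random.split(rng, 8), env_params)
```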

83 changes: 65 additions & 18 deletions examples/getting_started.ipynb
@@ -5,24 +5,25 @@
"metadata": {},
"source": [
"# `gymnax`: Classic Gym Environments in JAX\n",
"### Author: [@RobertTLange](https://twitter.com/RobertTLange) [Last Update: October 2021][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/gymnax/blob/main/examples/getting_started.ipynb)\n",
"### Author: [@RobertTLange](https://twitter.com/RobertTLange) [Last Update: November 2021][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/gymnax/blob/main/examples/getting_started.ipynb)\n",
"<a href=\"https://github.com/RobertTLange/gymnax/blob/main/docs/gymnax_logo.png?raw=true\"><img src=\"https://github.com/RobertTLange/gymnax/blob/main/docs/gymnax_logo.png?raw=true\" width=\"200\" align=\"right\" /></a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Basic API: `gymnax.make()`, `env.reset()`, `env.step()`"
"## Basic API: `gymnax.make()`, `env.reset()`, `env.step()`"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import gymnax\n",
"\n",
"rng = jax.random.PRNGKey(0)\n",
@@ -31,13 +32,13 @@
"env, env_params = gymnax.make(\"Pendulum-v1\")\n",
"\n",
"obs, state = env.reset(key_reset, env_params)\n",
"action = env.action_space.sample(key_policy)\n",
"action = env.action_space(env_params).sample(key_policy)\n",
"n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -48,12 +49,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Jitted Episode Rollouts via `lax.scan`"
"## Jitted Episode Rollouts via `lax.scan`"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -68,7 +69,7 @@
" num_output_units: int\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" def __call__(self, x, rng):\n",
" for l in range(self.num_hidden_layers):\n",
" x = nn.Dense(features=self.num_hidden_units)(x)\n",
" x = nn.relu(x)\n",
@@ -77,12 +78,12 @@
" \n",
"\n",
"network = MLP(48, 1, 1)\n",
"policy_params = network.init(rng, jnp.zeros(3))[\"params\"]"
"policy_params = network.init(rng, jnp.zeros(3), None)[\"params\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -95,8 +96,8 @@
" def policy_step(state_input, tmp):\n",
" \"\"\"lax.scan compatible step transition in jax env.\"\"\"\n",
" obs, state, policy_params, rng = state_input\n",
" rng, rng_step = jax.random.split(rng)\n",
" action = network.apply({\"params\": policy_params}, obs)\n",
" rng, rng_step, rng_net = jax.random.split(rng, 3)\n",
" action = network.apply({\"params\": policy_params}, obs, rng_net)\n",
" next_o, next_s, reward, done, _ = env.step(\n",
" rng_step, state, action, env_params\n",
" )\n",
@@ -118,15 +119,54 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(-1600.6174, dtype=float32)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Jit-Compiled Episode Rollout\n",
"jit_rollout = jax.jit(rollout, static_argnums=3)\n",
"jit_rollout(rng, policy_params, env_params, 200)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([-1076.0648, -1299.151 , -1502.7104, -1210.023 , -1287.3069,\n",
" -1267.0559, -1529.9363, -1517.0801, -1809.3643, -1565.3792], dtype=float32)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from gymnax.experimental.rollout import EnvRollout\n",
"\n",
"roller = EnvRollout(model_forward=network.apply,\n",
" env_name=\"Pendulum-v1\",\n",
" num_env_steps=200,\n",
" num_episodes=10)\n",
"\n",
"roller.collect(rng, policy_params)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -168,7 +208,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Distributed Anakin Agent on BSuite Catch Environment\n",
"# Distributed Anakin Agent\n",
"\n",
"\n",
"Adapted from Hessel et al. (2021) and DeepMind's [Example Colab](https://colab.research.google.com/drive/1974D-qP17fd5mLxy6QZv-ic4yxlPJp-G?usp=sharing#scrollTo=lhnJkrYLOvcs)"
@@ -418,6 +458,13 @@
"run_experiment(env, 128, 16, 1e-4, 100, 42)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Vectorized Population Evaluation for CMA-ES"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -428,9 +475,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python3 (mle-toolbox)",
"display_name": "snippets",
"language": "python",
"name": "mle-toolbox"
"name": "snippets"
},
"language_info": {
"codemirror_mode": {
@@ -442,7 +489,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
"version": "3.8.11"
}
},
"nbformat": 4,
File renamed without changes.
78 changes: 78 additions & 0 deletions gymnax/experimental/rollout.py
@@ -0,0 +1,78 @@
import jax
import jax.numpy as jnp
import gymnax


class EnvRollout(object):
    def __init__(
        self,
        model_forward,
        env_name: str = "Pendulum-v1",
        num_env_steps: int = 200,
        num_episodes: int = 20,
    ):
        """Wrapper defining batched episode rollouts for a set of policy parameters."""
        self.env_name = env_name
        self.num_env_steps = num_env_steps
        self.num_episodes = num_episodes

        # Define the RL environment & network forward function
        self.env, self.env_params = gymnax.make(self.env_name)
        self.model_forward = model_forward

        # Set up the vmap-ed batch evaluation function - rl/supervised/etc.
        self.gen_evaluate = self.batch_evaluate()

    def collect(self, rng_eval, policy_params):
        """Split the rng key and evaluate the policy across episodes."""
        rollout_keys = jax.random.split(rng_eval, self.num_episodes)

        # Evaluate the policy on a batch of episode rollouts
        pop_trajectories = self.gen_evaluate(rollout_keys, policy_params)
        return pop_trajectories

    def batch_evaluate(self):
        """Jit and vmap the single-episode rollout over episode rng keys."""
        # vmap over different MC fitness evaluations for a single network
        batch_rollout = jax.jit(
            jax.vmap(self.rollout, in_axes=(0, None), out_axes=0)
        )
        return batch_rollout

    def rollout(self, rng_input, policy_params):
        """Rollout a single episode with lax.scan."""
        # Reset the environment
        rng_reset, rng_episode = jax.random.split(rng_input)
        obs, state = self.env.reset(rng_reset, self.env_params)

        def policy_step(state_input, tmp):
            """lax.scan compatible step transition in jax env."""
            obs, state, policy_params, rng = state_input
            rng, rng_step, rng_net = jax.random.split(rng, 3)
            action = self.model_forward({"params": policy_params}, obs, rng=rng_net)
            next_o, next_s, reward, done, _ = self.env.step(
                rng_step, state, action, self.env_params
            )
            carry = [next_o.squeeze(), next_s, policy_params, rng]
            y = [next_o.squeeze(), reward, done]
            return carry, y

        # Scan over episode step loop - the xs below are only a dummy placeholder
        # whose leading dimension fixes the number of scanned steps
        _, scan_out = jax.lax.scan(
            policy_step,
            [obs, state, policy_params, rng_episode],
            [jnp.zeros((self.num_env_steps, self.input_shape[0] + 2))],
        )
        # Return the rewards accumulated by the agent in the episode rollout
        obs, rewards, dones = scan_out[0], scan_out[1], scan_out[2]
        rewards = rewards.reshape(self.num_env_steps, 1)
        ep_mask = (jnp.cumsum(dones) < 1).reshape(self.num_env_steps, 1)
        return obs, rewards, dones, jnp.sum(rewards * ep_mask)

    @property
    def input_shape(self):
        """Get the shape of the observation."""
        rng = jax.random.PRNGKey(0)
        obs, state = self.env.reset(rng, self.env_params)
        return obs.shape
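For orientation, a hedged usage sketch of the new wrapper, mirroring the notebook cell above; `network.apply`, `policy_params`, and `rng` are assumed to come from that cell, and the unpacking follows the four-tuple returned by `rollout` as defined in this file:

```python
from gymnax.experimental.rollout import EnvRollout

# Sketch only: `network` is the Flax MLP and `policy_params` its initialized
# parameters from the getting_started notebook, not part of this file.
roller = EnvRollout(model_forward=network.apply,
                    env_name="Pendulum-v1",
                    num_env_steps=200,
                    num_episodes=10)

obs, rewards, dones, returns = roller.collect(rng, policy_params)
# obs:     (10, 200, 3) - one Pendulum observation trajectory per episode
# rewards: (10, 200, 1) - per-step rewards
# returns: (10,)        - masked per-episode returns; returns.mean() gives a
#                         scalar score for ES-style generation evaluation
```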
