feat: add alphazero-like agent #5

Merged: 1 commit, Feb 20, 2024
7 changes: 7 additions & 0 deletions stoix/configs/default_ff_az.yaml
@@ -0,0 +1,7 @@
defaults:
  - logger: ff_az
  - arch: anakin
  - system: ff_az
  - network: mlp
  - env: gymnax/cartpole
  - _self_
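Side note on how this defaults list is consumed: Hydra composes the logger, arch, system, network and env groups into a single config at launch. Below is a minimal sketch of reading the composed config; the entry-point function name is hypothetical and not part of this PR.

import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="stoix/configs", config_name="default_ff_az", version_base="1.2")
def hypothetical_entry_point(cfg: DictConfig) -> None:
    # The composed groups are available as attributes, e.g. cfg.system.num_simulations.
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    hypothetical_entry_point()

Individual values can be overridden on the command line in the usual Hydra way, e.g. system.num_simulations=50.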
4 changes: 4 additions & 0 deletions stoix/configs/logger/ff_az.yaml
@@ -0,0 +1,4 @@
defaults:
  - base_logger

system_name: ff_az
23 changes: 23 additions & 0 deletions stoix/configs/system/ff_az.yaml
@@ -0,0 +1,23 @@
# --- Defaults FF-AZ ---

total_timesteps: 1e8 # Set the total environment steps.
# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
num_updates: ~ # Number of updates
seed: 42

# --- RL hyperparameters ---
actor_lr: 3e-4 # Learning rate for actor network
critic_lr: 3e-4 # Learning rate for critic network
update_batch_size: 1 # Number of vectorised gradient updates per device.
rollout_length: 16 # Number of environment steps per vectorised environment.
epochs: 1 # Number of epochs per training data batch.
num_minibatches: 8 # Number of minibatches per epoch.
gamma: 0.99 # Discounting factor.
gae_lambda: 0.95 # Lambda value for GAE computation.
ent_coef: 0.001 # Entropy regularisation term for loss function.
vf_coef: 1.0 # Critic weight in the loss function.
clip_eps: 0.2 # Clipping value for value function updates.
max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
num_simulations: 10 # Number of search simulations to run per action selection.
max_depth: ~ # Maximum depth of the search tree.
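To make the relationship between total_timesteps and num_updates above concrete, here is a rough, illustrative sketch of the bookkeeping. The exact derivation lives in the training code and may include extra factors (number of devices, update_batch_size), and the num_envs value below is an assumption rather than something set in this file.

# Illustrative only: how num_updates can be derived when total_timesteps is given.
total_timesteps = int(1e8)
rollout_length = 16       # env steps per vectorised env per update (from this config)
num_envs = 1024           # assumed number of vectorised envs (arch config, not shown here)
steps_per_update = rollout_length * num_envs       # 16_384 env steps collected per update
num_updates = total_timesteps // steps_per_update  # ~6_103 updates

Conversely, if num_updates were set instead, total_timesteps would be recovered as num_updates * steps_per_update.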
136 changes: 136 additions & 0 deletions stoix/systems/search/evaluator.py
@@ -0,0 +1,136 @@
from typing import Callable, Dict, Tuple

import chex
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from jumanji.env import Environment
from omegaconf import DictConfig

from stoix.systems.search.types import SearchApply
from stoix.types import EvalFn, EvalState, ExperimentOutput
from stoix.utils.jax import unreplicate_batch_dim


def get_search_evaluator_fn(
    env: Environment,
    search_apply_fn: SearchApply,
    root_fn: Callable,
    config: DictConfig,
    eval_multiplier: int = 1,
) -> EvalFn:
    """Get the evaluator function for search-based agents."""

    def eval_one_episode(params: FrozenDict, init_eval_state: EvalState) -> Dict:
        """Evaluate one episode. It is vectorized over the number of evaluation episodes."""

        def _env_step(eval_state: EvalState) -> EvalState:
            """Step the environment."""
            # PRNG keys.
            key, env_state, last_timestep, step_count, episode_return = eval_state

            # Select action.
            key, policy_key = jax.random.split(key)
            obs, model_env_state = jax.tree_map(
                lambda x: x[jnp.newaxis, ...], (last_timestep.observation, env_state)
            )
            root = root_fn(params, obs, model_env_state)
            search_output = search_apply_fn(params, policy_key, root)
            action = search_output.action

            # Step environment.
            env_state, timestep = env.step(env_state, action.squeeze())

            # Log episode metrics.
            episode_return += timestep.reward
            step_count += 1
            eval_state = EvalState(key, env_state, timestep, step_count, episode_return)
            return eval_state

        def not_done(carry: Tuple) -> bool:
            """Check if the episode is done."""
            timestep = carry[2]
            is_not_done: bool = ~timestep.last()
            return is_not_done

        final_state = jax.lax.while_loop(not_done, _env_step, init_eval_state)

        eval_metrics = {
            "episode_return": final_state.episode_return,
            "episode_length": final_state.step_count,
        }

        return eval_metrics

    def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> ExperimentOutput[EvalState]:
        """Evaluator function."""

        # Initialise environment states and timesteps.
        n_devices = len(jax.devices())

        eval_batch = (config.arch.num_eval_episodes // n_devices) * eval_multiplier

        key, *env_keys = jax.random.split(key, eval_batch + 1)
        env_states, timesteps = jax.vmap(env.reset)(
            jnp.stack(env_keys),
        )
        # Split keys for each core.
        key, *step_keys = jax.random.split(key, eval_batch + 1)
        # Add dimension to pmap over.
        step_keys = jnp.stack(step_keys).reshape(eval_batch, -1)

        eval_state = EvalState(
            key=step_keys,
            env_state=env_states,
            timestep=timesteps,
            step_count=jnp.zeros((eval_batch, 1)),
            episode_return=jnp.zeros_like(timesteps.reward),
        )

        eval_metrics = jax.vmap(
            eval_one_episode,
            in_axes=(None, 0),
            axis_name="eval_batch",
        )(trained_params, eval_state)

        return ExperimentOutput(
            learner_state=eval_state,
            episode_metrics=eval_metrics,
            train_metrics={},
        )

    return evaluator_fn


def search_evaluator_setup(
    eval_env: Environment,
    key_e: chex.PRNGKey,
    search_apply_fn: SearchApply,
    root_fn: Callable,
    params: FrozenDict,
    config: DictConfig,
) -> Tuple[EvalFn, EvalFn, Tuple[FrozenDict, chex.Array]]:
    """Initialise evaluator_fn."""
    # Get the number of available devices (e.g. TPU/GPU cores).
    n_devices = len(jax.devices())

    eval_apply_fn = search_apply_fn
    evaluator = get_search_evaluator_fn(eval_env, eval_apply_fn, root_fn, config)
    absolute_metric_evaluator = get_search_evaluator_fn(
        eval_env,
        eval_apply_fn,
        root_fn,
        config,
        10,
    )

    evaluator = jax.pmap(evaluator, axis_name="device")
    absolute_metric_evaluator = jax.pmap(absolute_metric_evaluator, axis_name="device")

    # Broadcast trained params to cores and split keys for each core.
    trained_params = unreplicate_batch_dim(params)
    key_e, *eval_keys = jax.random.split(key_e, n_devices + 1)
    eval_keys = jnp.stack(eval_keys).reshape(n_devices, -1)

    return evaluator, absolute_metric_evaluator, (trained_params, eval_keys)
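For context, here is a rough usage sketch of how this setup function is expected to be wired into a training script; env, params, config, root_fn and search_apply_fn are assumed to have been built by the AlphaZero system setup and are not defined in this file.

import jax

# Hypothetical wiring; all inputs below come from the system's own setup code.
evaluator, absolute_metric_evaluator, (trained_params, eval_keys) = search_evaluator_setup(
    eval_env=env,
    key_e=jax.random.PRNGKey(config.system.seed),
    search_apply_fn=search_apply_fn,
    root_fn=root_fn,
    params=params,
    config=config,
)

# The returned evaluator is pmapped over devices; each device runs
# config.arch.num_eval_episodes // n_devices episodes in parallel via vmap.
eval_output = evaluator(trained_params, eval_keys)
mean_return = eval_output.episode_metrics["episode_return"].mean()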