
Commit

work-in-progress
EdanToledo committed Feb 16, 2025
1 parent 56beb38 commit a19f31b
Showing 8 changed files with 870 additions and 1 deletion.
Binary file removed docs/images/stoix.png
12 changes: 12 additions & 0 deletions stoix/base_types.py
@@ -148,6 +148,18 @@ class OffPolicyLearnerState(NamedTuple):
    timestep: TimeStep


class RNNOffPolicyLearnerState(NamedTuple):
    params: Parameters
    opt_states: OptStates
    buffer_state: BufferState
    key: chex.PRNGKey
    env_state: LogEnvState
    timestep: TimeStep
    dones: Done
    truncated: Truncated
    hstates: HiddenStates


class OnlineAndTarget(NamedTuple):
    online: FrozenDict
    target: FrozenDict
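The new RNNOffPolicyLearnerState carry extends OffPolicyLearnerState with episode-termination flags and recurrent hidden states so they can be threaded through the scanned update loop. A minimal sketch of how such a carry might be assembled at the start of training (the helper below and its shapes are illustrative assumptions, not code from this commit):

import jax.numpy as jnp

from stoix.base_types import RNNOffPolicyLearnerState

def make_initial_learner_state(params, opt_states, buffer_state, key,
                               env_state, timestep, num_envs, hidden_dim):
    # Hypothetical helper: zero-initialise the recurrent carry and the episode
    # flags, one entry per vectorised environment.
    dones = jnp.zeros((num_envs,), dtype=bool)      # no episode has terminated yet
    truncated = jnp.zeros((num_envs,), dtype=bool)  # no episode has been truncated yet
    hstates = jnp.zeros((num_envs, hidden_dim))     # zeroed RNN hidden state
    return RNNOffPolicyLearnerState(
        params=params,
        opt_states=opt_states,
        buffer_state=buffer_state,
        key=key,
        env_state=env_state,
        timestep=timestep,
        dones=dones,
        truncated=truncated,
        hstates=hstates,
    )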
2 changes: 1 addition & 1 deletion stoix/configs/arch/anakin.yaml
@@ -3,7 +3,7 @@ architecture_name: anakin
# --- Training ---
seed: 42 # RNG seed.
update_batch_size: 1 # Number of vectorised gradient updates per device.
- total_num_envs: 1024 # Total Number of vectorised environments across all devices and batched_updates. Needs to be divisible by n_devices*update_batch_size.
+ total_num_envs: 8 # Total Number of vectorised environments across all devices and batched_updates. Needs to be divisible by n_devices*update_batch_size.
total_timesteps: 1e7 # Set the total environment steps.
# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
num_updates: ~ # Number of updates
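total_num_envs (dropped here from 1024 to 8, presumably for debugging this work-in-progress system) is split evenly across devices and update batches. A small illustration of the divisibility constraint, with hypothetical device counts:

# Hypothetical numbers, only to illustrate the divisibility requirement.
n_devices = 2
update_batch_size = 1
total_num_envs = 8
assert total_num_envs % (n_devices * update_batch_size) == 0
envs_per_update_batch = total_num_envs // (n_devices * update_batch_size)  # 4 envs per device per update batch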
11 changes: 11 additions & 0 deletions stoix/configs/default/anakin/default_rec_r2d2.yaml
@@ -0,0 +1,11 @@
defaults:
  - logger: base_logger
  - arch: anakin
  - system: q_learning/rec_r2d2
  - network: rnn_dqn
  - env: gymnax/cartpole
  - _self_

hydra:
  searchpath:
    - file://stoix/configs
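The new default config is pure Hydra composition: it stitches together the anakin architecture, the rec_r2d2 system, the rnn_dqn network, and a gymnax CartPole environment. One way to check that the composition resolves, sketched with Hydra's standard compose API (run from the repository root; the overrides shown are arbitrary examples, not values from this commit):

from hydra import compose, initialize

# Compose the rec_r2d2 defaults the same way a run script would; any key can be
# overridden with Hydra's usual dot-list syntax.
with initialize(version_base=None, config_path="stoix/configs/default/anakin"):
    cfg = compose(
        config_name="default_rec_r2d2",
        overrides=["arch.total_num_envs=8", "system.rollout_length=4"],
    )
print(cfg.system.system_name)  # rec_r2d2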
19 changes: 19 additions & 0 deletions stoix/configs/network/rnn_dqn.yaml
@@ -0,0 +1,19 @@
# --- Recurrent Structure Networks for DQN ---

actor_network:
  pre_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [128]
    use_layer_norm: False
    activation: silu
  rnn_layer:
    _target_: stoix.networks.base.ScannedRNN
    cell_type: gru
    hidden_state_dim: 128
  post_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [128]
    use_layer_norm: False
    activation: silu
  action_head:
    _target_: stoix.networks.heads.DiscreteQNetworkHead
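The network is the usual torso → recurrent cell → torso → Q-value head stack; the only genuinely recurrent piece is the ScannedRNN layer, which unrolls a GRU over the time axis. The commit does not show stoix.networks.base.ScannedRNN itself, so the following is only a generic Flax sketch of that pattern (with the carry reset at episode boundaries, since the learner state now tracks dones), not the repository's implementation:

import functools

import flax.linen as nn
import jax.numpy as jnp

class ScannedGRU(nn.Module):
    """Illustrative scanned GRU; the repository's ScannedRNN may differ."""

    hidden_state_dim: int

    @functools.partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        embedding, done = x
        # Reset the hidden state wherever an episode ended on the previous step.
        carry = jnp.where(done[:, None], jnp.zeros_like(carry), carry)
        new_carry, output = nn.GRUCell(features=self.hidden_state_dim)(carry, embedding)
        return new_carry, output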
24 changes: 24 additions & 0 deletions stoix/configs/system/q_learning/rec_r2d2.yaml
@@ -0,0 +1,24 @@
# --- Defaults Rec-R2D2 ---

system_name: rec_r2d2 # Name of the system.

# --- RL hyperparameters ---
rollout_length: 4 # Number of environment steps per vectorised environment.
epochs: 128 # Number of SGD steps per rollout.
warmup_steps: 16 # Number of steps to collect before training.
total_buffer_size: 1_000_000 # Total effective size of the replay buffer across all devices and vectorised update steps. This means each device has a buffer of size buffer_size//num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
total_batch_size: 512 # Total effective number of samples to train on. This means each device has a batch size of batch_size/num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
burn_in_length: 40 # Number of steps used to burn in the RNN hidden state before the loss is computed.
sample_sequence_length: 80 # Length of the sequence to sample from the buffer.
priority_exponent: 0.5 # Exponent for prioritised experience replay.
importance_sampling_exponent: 0.4 # Exponent for the importance sampling weights.
priority_eta: 0.9 # Balance between max and mean priorities.
n_step: 5 # Number of steps to use for the n-step return.
q_lr: 6.25e-5 # Learning rate of the Q-network optimizer.
tau: 0.005 # Smoothing coefficient for the target network update.
gamma: 0.99 # Discount factor.
max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
training_epsilon: 0.0 # Epsilon for the epsilon-greedy policy during training.
evaluation_epsilon: 0.0 # Epsilon for the epsilon-greedy policy during evaluation.
max_abs_reward: 1000.0 # Maximum absolute reward value.
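Most hyperparameters follow the R2D2 recipe: 80-step sequences of which the first 40 only burn in the hidden state, 5-step returns, and sequence priorities that mix the max and mean absolute TD error via priority_eta. A sketch of that priority combination (the actual loss code is not part of this diff):

import jax.numpy as jnp

def sequence_priority(td_errors: jnp.ndarray, priority_eta: float = 0.9) -> jnp.ndarray:
    """R2D2-style priority: eta * max_t |delta_t| + (1 - eta) * mean_t |delta_t|.
    td_errors is expected to have shape (batch, time)."""
    abs_td = jnp.abs(td_errors)
    return priority_eta * abs_td.max(axis=-1) + (1.0 - priority_eta) * abs_td.mean(axis=-1)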
13 changes: 13 additions & 0 deletions stoix/systems/q_learning/dqn_types.py
@@ -3,6 +3,8 @@
import chex
from typing_extensions import NamedTuple

from stoix.base_types import HiddenState


class Transition(NamedTuple):
    obs: chex.ArrayTree
@@ -11,3 +13,14 @@ class Transition(NamedTuple):
    done: chex.Array
    next_obs: chex.Array
    info: Dict


class RNNTransition(NamedTuple):
    obs: chex.ArrayTree
    action: chex.Array
    reward: chex.Array
    done: chex.Array
    truncated: chex.Array
    next_obs: chex.Array
    info: Dict
    hstate: HiddenState
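RNNTransition mirrors Transition but additionally records the truncation flag and the hidden state the recurrent Q-network held when the action was chosen, so that stored sequences can be replayed from a sensible initial carry. A toy, single-environment example of filling one in (shapes are placeholders, not tied to any environment in the repo):

import jax.numpy as jnp

from stoix.systems.q_learning.dqn_types import RNNTransition

# Placeholder shapes: a 4-dimensional observation and a 128-dimensional GRU carry.
transition = RNNTransition(
    obs=jnp.zeros((4,)),
    action=jnp.asarray(1),
    reward=jnp.asarray(1.0),
    done=jnp.asarray(False),
    truncated=jnp.asarray(False),
    next_obs=jnp.zeros((4,)),
    info={},
    hstate=jnp.zeros((128,)),
)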