chore: slight changes to eval learner state device placement
EdanToledo committed Nov 4, 2024
1 parent 0e04987 commit ef08c9a
Showing 2 changed files with 9 additions and 9 deletions.
10 changes: 5 additions & 5 deletions stoix/configs/arch/sebulba.yaml
@@ -2,19 +2,19 @@
 architecture_name : sebulba
 # --- Training ---
 seed: 42 # RNG seed.
-total_num_envs: 1024 # Total Number of vectorised environments across all actors. Needs to be divisible by the number of actor devices and actors per device.
-total_timesteps: 1e7 # Set the total environment steps.
+total_num_envs: 128 # Total Number of vectorised environments across all actors. Needs to be divisible by the number of actor devices and actors per device.
+total_timesteps: 1e5 # Set the total environment steps.
 # If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
 num_updates: ~ # Number of updates

 # Define the number of actors per device and which devices to use.
 actor:
-  device_ids: [0,1] # Define which devices to use for the actors.
-  actor_per_device: 2 # number of different threads per actor device.
+  device_ids: [0] # Define which devices to use for the actors.
+  actor_per_device: 1 # number of different threads per actor device.

 # Define which devices to use for the learner.
 learner:
-  device_ids: [2,3] # Define which devices to use for the learner.
+  device_ids: [1] # Define which devices to use for the learner.

 # Size of the queue for the pipeline where actors push data and the learner pulls data.
 pipeline_queue_size: 10
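For context on what these fields control: `device_ids` index into the locally visible JAX devices, and `total_num_envs` must split evenly across the actor threads. Below is a minimal sketch of how a launcher might resolve this config; the `OmegaConf.load` call and variable names are illustrative assumptions, not Stoix's actual startup code.

```python
import jax
from omegaconf import OmegaConf

# Hypothetical: load the sebulba arch config shown in the diff above.
config = OmegaConf.load("stoix/configs/arch/sebulba.yaml")

local_devices = jax.devices()
# device_ids index into the locally visible JAX devices.
actor_devices = [local_devices[i] for i in config.actor.device_ids]
learner_devices = [local_devices[i] for i in config.learner.device_ids]

# With device_ids [0] and actor_per_device 1, there is a single actor
# thread, so total_num_envs (128) trivially divides across it.
num_actor_threads = len(actor_devices) * config.actor.actor_per_device
assert int(config.total_num_envs) % num_actor_threads == 0
envs_per_thread = int(config.total_num_envs) // num_actor_threads
```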
8 changes: 4 additions & 4 deletions stoix/systems/ppo/sebulba/ff_ppo.py
@@ -809,9 +809,9 @@ def run_experiment(_config: DictConfig) -> float:
         logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

         # Evaluate the current model and log the metrics
-        learner_state_cpu = jax.device_get(learner_state)
+        eval_learner_state = jax.device_put(learner_state, evaluator_device)
         key, eval_key = jax.random.split(key, 2)
-        eval_metrics = evaluator(learner_state_cpu.params.actor_params, eval_key)
+        eval_metrics = evaluator(eval_learner_state.params.actor_params, eval_key)
         logger.log(eval_metrics, t, eval_step, LogEvent.EVAL)

         episode_return = jnp.mean(eval_metrics["episode_return"])
@@ -820,12 +820,12 @@ def run_experiment(_config: DictConfig) -> float:
         # Save checkpoint of learner state
         checkpointer.save(
             timestep=steps_consumed_per_eval * (eval_step + 1),
-            unreplicated_learner_state=learner_state_cpu,
+            unreplicated_learner_state=jax.device_get(learner_state),
             episode_return=episode_return,
         )

         if config.arch.absolute_metric and max_episode_return <= episode_return:
-            best_params = copy.deepcopy(learner_state_cpu.params.actor_params)
+            best_params = copy.deepcopy(eval_learner_state.params.actor_params)
             max_episode_return = episode_return

     evaluator_envs.close()
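The net effect of this change: the learner state is copied straight to the evaluator's device with `jax.device_put`, so evaluation no longer runs on host-fetched arrays, while `jax.device_get` is now called only at checkpoint time, where host-side numpy arrays are actually needed. A self-contained sketch of the two primitives, using a toy parameter pytree as a stand-in for Stoix's real learner state:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a learner state pytree.
params = {"actor_params": {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)}}

evaluator_device = jax.devices()[0]  # e.g. the device reserved for evaluation

# device_put: copy the pytree onto a specific device; the leaves stay
# jax.Arrays, and computation on them runs on that device.
eval_params = jax.device_put(params, evaluator_device)
print(eval_params["actor_params"]["w"].devices())  # {evaluator_device}

# device_get: pull the pytree back to host memory as numpy arrays, the
# form a checkpointer typically wants to serialise.
host_params = jax.device_get(params)
print(type(host_params["actor_params"]["b"]))  # <class 'numpy.ndarray'>
```

One hedged side note: deep-copying `eval_learner_state.params.actor_params` for `best_params` now copies device-backed arrays rather than host numpy arrays, which works the same downstream but keeps those buffers resident on the evaluator device.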
