Skip to content

Commit

Permalink
Clean up; fix the SPS calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew-Luo1 committed Apr 15, 2024
1 parent 7e67df2 commit e768157
Showing 1 changed file with 16 additions and 29 deletions.
45 changes: 16 additions & 29 deletions brax/training/agents/apg/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ def loss(policy_params, normalizer_params, env_state, key):
(state_h, _), (rewards,
obs) = jax.lax.scan(f, (env_state, key),
(jnp.arange(horizon_length // action_repeat)))
jax.debug.print("{x}", x=jnp.mean(rewards))

return -jnp.mean(rewards), (obs, state_h)

Expand Down Expand Up @@ -245,7 +244,7 @@ def training_epoch_with_timing(training_state: TrainingState,

epoch_training_time = time.time() - t
training_walltime += epoch_training_time
sps = (episode_length * num_envs) / epoch_training_time
sps = (updates_per_epoch * num_envs * horizon_length) / epoch_training_time
metrics = {
'training/sps': sps,
'training/walltime': training_walltime,
Expand Down Expand Up @@ -292,13 +291,13 @@ def training_epoch_with_timing(training_state: TrainingState,
# Run initial eval
metrics = {}

# if process_id == 0 and num_evals > 1:
# metrics = evaluator.run_evaluation(
# _unpmap(
# (training_state.normalizer_params, training_state.policy_params)),
# training_metrics={})
# logging.info(metrics)
# progress_fn(0, metrics)
if process_id == 0 and num_evals > 1:
metrics = evaluator.run_evaluation(
_unpmap(
(training_state.normalizer_params, training_state.policy_params)),
training_metrics={})
logging.info(metrics)
progress_fn(0, metrics)

init_key, scramble_key, local_key = jax.random.split(local_key, 3)
init_key = jax.random.split(init_key, (local_devices_to_use, num_envs // process_count))
Expand All @@ -313,29 +312,17 @@ def training_epoch_with_timing(training_state: TrainingState,
for it in range(num_evals_after_init):
logging.info('starting iteration %s %s', it, time.time() - xt)

# optimization
from pathlib import Path
import pickle

file_name = f'checkpoint_{it}.pkl'
base_path = "/tmp/checkpoints/"
save_to = str(Path(
Path(base_path),
Path(file_name)))
algo_state = {'training_state': training_state}
pickle.dump(algo_state, open(save_to, "wb"))

(training_state, env_state,
training_metrics, epoch_key) = training_epoch_with_timing(training_state, env_state, epoch_key)

# if process_id == 0:
# # Run evals.
# metrics = evaluator.run_evaluation(
# _unpmap(
# (training_state.normalizer_params, training_state.policy_params)),
# training_metrics)
# logging.info(metrics)
# progress_fn(it + 1, metrics)
if process_id == 0:
# Run evals.
metrics = evaluator.run_evaluation(
_unpmap(
(training_state.normalizer_params, training_state.policy_params)),
training_metrics)
logging.info(metrics)
progress_fn(it + 1, metrics)

# If there were no mistakes, the training_state should still be identical on all
# devices.
Expand Down

0 comments on commit e768157

Please sign in to comment.