diff --git a/brax/training/agents/apg/train.py b/brax/training/agents/apg/train.py
index 0c6ad2fc..c635fbdf 100644
--- a/brax/training/agents/apg/train.py
+++ b/brax/training/agents/apg/train.py
@@ -174,7 +174,6 @@ def loss(policy_params, normalizer_params, env_state, key):
     (state_h, _), (rewards, obs) = jax.lax.scan(
         f, (env_state, key),
         (jnp.arange(horizon_length // action_repeat)))
-    jax.debug.print("{x}", x=jnp.mean(rewards))
     return -jnp.mean(rewards), (obs, state_h)
@@ -245,7 +244,7 @@ training_epoch_with_timing(training_state: TrainingState,
     epoch_training_time = time.time() - t
     training_walltime += epoch_training_time
-    sps = (episode_length * num_envs) / epoch_training_time
+    sps = (updates_per_epoch * num_envs * horizon_length) / epoch_training_time
     metrics = {
         'training/sps': sps,
         'training/walltime': training_walltime,
@@ -292,13 +291,13 @@ training_epoch_with_timing(training_state: TrainingState,
   # Run initial eval
   metrics = {}
-  # if process_id == 0 and num_evals > 1:
-  #   metrics = evaluator.run_evaluation(
-  #       _unpmap(
-  #           (training_state.normalizer_params, training_state.policy_params)),
-  #       training_metrics={})
-  #   logging.info(metrics)
-  #   progress_fn(0, metrics)
+  if process_id == 0 and num_evals > 1:
+    metrics = evaluator.run_evaluation(
+        _unpmap(
+            (training_state.normalizer_params, training_state.policy_params)),
+        training_metrics={})
+    logging.info(metrics)
+    progress_fn(0, metrics)
 
   init_key, scramble_key, local_key = jax.random.split(local_key, 3)
   init_key = jax.random.split(init_key, (local_devices_to_use,
                                          num_envs // process_count))
@@ -313,29 +312,17 @@ training_epoch_with_timing(training_state: TrainingState,
   for it in range(num_evals_after_init):
     logging.info('starting iteration %s %s', it, time.time() - xt)
 
-    # optimization
-    from pathlib import Path
-    import pickle
-
-    file_name = f'checkpoint_{it}.pkl'
-    base_path = "/tmp/checkpoints/"
-    save_to = str(Path(
-        Path(base_path),
-        Path(file_name)))
-    algo_state = {'training_state': training_state}
-    pickle.dump(algo_state, open(save_to, "wb"))
-
     (training_state, env_state, training_metrics,
      epoch_key) = training_epoch_with_timing(training_state, env_state,
                                              epoch_key)
 
-    # if process_id == 0:
-    #   # Run evals.
-    #   metrics = evaluator.run_evaluation(
-    #       _unpmap(
-    #           (training_state.normalizer_params, training_state.policy_params)),
-    #       training_metrics)
-    #   logging.info(metrics)
-    #   progress_fn(it + 1, metrics)
+    if process_id == 0:
+      # Run evals.
+      metrics = evaluator.run_evaluation(
+          _unpmap(
+              (training_state.normalizer_params, training_state.policy_params)),
+          training_metrics)
+      logging.info(metrics)
+      progress_fn(it + 1, metrics)
 
   # If there was no mistakes the training_state should still be identical on all
   # devices.
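
For reference, the corrected `training/sps` metric counts every simulated environment step an epoch performs: each of the `updates_per_epoch` gradient updates unrolls `num_envs` environments for `horizon_length` steps, whereas the old formula used `episode_length`, which is unrelated to the work actually done per epoch. A minimal standalone sketch of the accounting, with hypothetical values that are not part of this patch:

# Sketch of the corrected SPS accounting; all values below are hypothetical.
updates_per_epoch = 8      # gradient updates per training epoch
num_envs = 256             # parallel environments
horizon_length = 32        # steps unrolled per update (the BPTT horizon)
epoch_training_time = 4.0  # seconds, as timed around training_epoch

# Each update simulates num_envs * horizon_length environment steps,
# and an epoch performs updates_per_epoch such updates.
env_steps_per_epoch = updates_per_epoch * num_envs * horizon_length
sps = env_steps_per_epoch / epoch_training_time
print(sps)  # 16384.0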