diff --git a/lwm/train.py b/lwm/train.py index 1e1fe28..c03d09a 100644 --- a/lwm/train.py +++ b/lwm/train.py @@ -365,7 +365,7 @@ def save_checkpoint(train_state, milestone=False): train_state, sharded_rng, batch ) if step % FLAGS.log_freq == 0: - if FLAGS.eval_steps > 0: + if FLAGS.eval_steps > 0 and step % 10 == 0: eval_metric_list = [] for _ in range(FLAGS.eval_steps): eval_batch, _ = next(eval_iterator) @@ -376,11 +376,10 @@ def save_checkpoint(train_state, milestone=False): eval_metric_list.append(eval_metrics) metrics.update(average_metrics(eval_metric_list)) - log_metrics = {"step": step} - log_metrics.update(metrics) + log_metrics = metrics log_metrics.update(dataset_metrics) log_metrics = jax.device_get(log_metrics) - logger.log(log_metrics) + logger.log(log_metrics, step=step) tqdm.write("\n" + pprint.pformat(log_metrics) + "\n") if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0: