diff --git a/init2winit/main.py b/init2winit/main.py
index ebe044c7..bcb19d45 100644
--- a/init2winit/main.py
+++ b/init2winit/main.py
@@ -74,6 +74,7 @@
     'make from tf.data.map. Set to -1 for tf.data.AUTOTUNE.'
 )
 flags.DEFINE_integer('eval_batch_size', None, 'Batch size for evaluation.')
+flags.DEFINE_bool('eval_use_ema', None, 'If True evals will use ema of params.')
 flags.DEFINE_integer(
     'eval_num_batches', None,
     'Number of batches for evaluation. Leave None to evaluate '
@@ -169,6 +170,7 @@ def _run(
     dataset_name,
     data_selector_name,
     eval_batch_size,
+    eval_use_ema,
     eval_num_batches,
     test_num_batches,
     eval_train_num_batches,
@@ -256,6 +258,7 @@ def _run(
       merged_hps,
       rng,
       eval_batch_size,
+      eval_use_ema,
       eval_num_batches,
       test_num_batches,
       eval_train_num_batches,
@@ -325,6 +328,7 @@ def main(unused_argv):
       dataset_name=FLAGS.dataset,
       data_selector_name=FLAGS.data_selector,
       eval_batch_size=FLAGS.eval_batch_size,
+      eval_use_ema=FLAGS.eval_use_ema,
       eval_num_batches=FLAGS.eval_num_batches,
       test_num_batches=FLAGS.test_num_batches,
       eval_train_num_batches=FLAGS.eval_train_num_batches,
diff --git a/init2winit/optimizer_lib/kitchen_sink/_src/transform.py b/init2winit/optimizer_lib/kitchen_sink/_src/transform.py
index 8a32f0ba..0d868144 100644
--- a/init2winit/optimizer_lib/kitchen_sink/_src/transform.py
+++ b/init2winit/optimizer_lib/kitchen_sink/_src/transform.py
@@ -157,6 +157,59 @@ def polyak_hb(
   )
 
 
+def compute_params_ema_for_eval(
+    decay: float, warmup: bool = False
+) -> optax.GradientTransformation:
+  """Applies exponential moving average on weights.
+
+  Note, this implementation averages the weight before optimization because
+  trainable and non-trainable variables are handled separately. In such case
+  the updates on non-trainable variables like bn stats are not available in
+  updates.
+
+  This differs from optax.ema which applies ema on gradients so it changes
+  training process.
+
+  ema = ema * decay + new_weight * (1.0 - decay)
+
+  Args:
+    decay: A float number represents the weight on the moving average.
+    warmup: bool controlling if we ignore initial training steps for EMA.
+
+  Returns:
+    A GradientTransformation applying ema.
+  """
+
+  def init_fn(params):
+    return optax.EmaState(
+        count=jnp.array(0, dtype=jnp.int32), ema=jax.tree_map(jnp.copy, params))
+
+  def update_fn(updates, state, params):
+    if params is None:
+      raise ValueError('Params required for the EMA')
+
+    if warmup:
+      # https://github.com/tensorflow/tensorflow/blob/v2.9.1/tensorflow/python/training/moving_averages.py#L469
+      ema_decay = jnp.minimum(decay, (1. + state.count) / (10. + state.count))
+    else:
+      ema_decay = decay
+
+    def update_func(old_v, new_v):
+      if old_v.dtype == jnp.bool_ or jnp.issubdtype(old_v, jnp.integer):
+        # If it is integer, we directly return the new variable
+        # This is mainly supported for non_trainable
+        return new_v
+      else:
+        return old_v - (1.0 - ema_decay) * (old_v - new_v)
+
+    new_ema = jax.tree_map(update_func, state.ema, params)
+    count_inc = state.count + jnp.array(1, jnp.int32)
+
+    return updates, optax.EmaState(count=count_inc, ema=new_ema)
+
+  return optax.GradientTransformation(init_fn, update_fn)
+
+
 def first_moment_ema(
     decay: float = 0.9,
     debias: bool = False,
@@ -1706,6 +1759,7 @@ def update_fn(updates, state, params):
     'ema_nesterov': ema_nesterov,
     'polyak_hb': polyak_hb,
     'first_moment_ema': first_moment_ema,
+    'compute_params_ema_for_eval': compute_params_ema_for_eval,
     'normalized_first_moment_ema': normalized_first_moment_ema,
     'nesterovpp': nesterovpp,
 }
diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py
index 555db8c1..132ed7ad 100644
--- a/init2winit/trainer_lib/base_trainer.py
+++ b/init2winit/trainer_lib/base_trainer.py
@@ -31,6 +31,7 @@
 from init2winit.training_metrics_grabber import make_training_metrics
 import jax
 import numpy as np
+import optax
 import orbax.checkpoint as orbax_checkpoint
 
 
@@ -47,6 +48,7 @@ def __init__(
       hps,
       rng,
       eval_batch_size,
+      eval_use_ema,
       eval_num_batches,
       test_num_batches,
       eval_train_num_batches,
@@ -85,6 +87,7 @@ def __init__(
       rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
         shuffling.
       eval_batch_size: the evaluation batch size. If None, use hps.batch_size.
+      eval_use_ema: if True evals will use ema of params.
       eval_num_batches: (int) The number of batches used for evaluating on
         validation sets. Set to None to evaluate on the whole eval set.
       test_num_batches: (int) The number of batches used for evaluating on test
@@ -146,8 +149,10 @@ def __init__(
     self._hps = hps
     self._rng = rng
     eval_batch_size = (
-        self._hps.batch_size if eval_batch_size is None else eval_batch_size)
+        self._hps.batch_size if eval_batch_size is None else eval_batch_size
+    )
     self._eval_batch_size = eval_batch_size
+    self._eval_use_ema = eval_use_ema
     self._eval_num_batches = eval_num_batches
     self._test_num_batches = test_num_batches
     self._eval_train_num_batches = eval_train_num_batches
@@ -534,12 +539,28 @@ def _eval(
     Returns:
       A Dict[str, Any] eval report, originally created in
       trainer_utils.eval_metrics.
+
     """
     time_since_last_eval = time.time() - self._time_at_prev_eval_end
     self._batch_stats = trainer_utils.maybe_sync_batchnorm_stats(
-        self._batch_stats)
+        self._batch_stats
+    )
+
+    if self._eval_use_ema:
+      if isinstance(
+          self._optimizer_state.base_state.inner_state[0][0], optax.EmaState
+      ):
+        eval_params = self._optimizer_state.base_state.inner_state[0][0].ema
+      else:
+        raise ValueError(
+            'EMA computation should be the very first transformation in defined'
+            ' kitchensink optimizer.'
+        )
+    else:
+      eval_params = self._params
+
     report, eval_time = trainer_utils.eval_metrics(
-        self._params,
+        eval_params,
         self._batch_stats,
         self._dataset,
         self._eval_num_batches,
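
Note (not part of the patch): the new transformation has to sit at the front of the kitchen sink chain, because base_trainer looks for the optax.EmaState at inner_state[0][0] of the optimizer state and raises otherwise; evals then switch to the averaged copy when --eval_use_ema=True is passed. The sketch below shows that state layout with plain optax, stripped of the init2winit wrappers that base_trainer reaches through self._optimizer_state.base_state; the decay value and the use of optax.sgd are illustrative assumptions, only compute_params_ema_for_eval itself comes from this patch.

  import jax.numpy as jnp
  import optax
  from init2winit.optimizer_lib.kitchen_sink._src.transform import (
      compute_params_ema_for_eval)

  # EMA tracking first, then the "real" update rule, mirroring the ordering
  # the trainer's isinstance check expects.
  tx = optax.chain(
      compute_params_ema_for_eval(decay=0.999, warmup=True),
      optax.sgd(learning_rate=0.1),
  )

  params = {'w': jnp.ones((3,))}
  state = tx.init(params)

  # The chained state is a tuple; its first element is the optax.EmaState
  # whose .ema field starts as a copy of params and is what _eval reads.
  assert isinstance(state[0], optax.EmaState)
  eval_params = state[0].ema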
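
The update rule in update_func is the usual EMA written in residual form: old_v - (1 - decay) * (old_v - new_v) is algebraically the same convex combination decay * old + (1 - decay) * new quoted in the docstring, and with warmup=True the effective decay is capped at (1 + count) / (10 + count), so early steps track the raw weights closely. A quick numerical check; the decay and step values here are arbitrary, not taken from the patch.

  import jax.numpy as jnp

  decay = 0.99
  old, new = jnp.float32(2.0), jnp.float32(1.0)

  residual = old - (1.0 - decay) * (old - new)   # form used in update_func
  convex = decay * old + (1.0 - decay) * new     # form quoted in the docstring
  assert jnp.allclose(residual, convex)

  # Effective decay under warmup: min(decay, (1 + count) / (10 + count)).
  for count in (0, 10, 1000):
    print(count, jnp.minimum(decay, (1.0 + count) / (10.0 + count)))
  # 0 -> 0.1, 10 -> 0.55, 1000 -> 0.99 (the configured decay)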