Skip to content

Commit

Permalink
Internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 544664653
  • Loading branch information
sourabh2k15 authored and copybara-github committed Sep 28, 2023
1 parent 7f497a0 commit 751a51d
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 3 deletions.
4 changes: 4 additions & 0 deletions init2winit/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
'make from tf.data.map. Set to -1 for tf.data.AUTOTUNE.'
)
flags.DEFINE_integer('eval_batch_size', None, 'Batch size for evaluation.')
flags.DEFINE_bool('eval_use_ema', None, 'If True evals will use ema of params.')
flags.DEFINE_integer(
'eval_num_batches', None,
'Number of batches for evaluation. Leave None to evaluate '
Expand Down Expand Up @@ -169,6 +170,7 @@ def _run(
dataset_name,
data_selector_name,
eval_batch_size,
eval_use_ema,
eval_num_batches,
test_num_batches,
eval_train_num_batches,
Expand Down Expand Up @@ -256,6 +258,7 @@ def _run(
merged_hps,
rng,
eval_batch_size,
eval_use_ema,
eval_num_batches,
test_num_batches,
eval_train_num_batches,
Expand Down Expand Up @@ -325,6 +328,7 @@ def main(unused_argv):
dataset_name=FLAGS.dataset,
data_selector_name=FLAGS.data_selector,
eval_batch_size=FLAGS.eval_batch_size,
eval_use_ema=FLAGS.eval_use_ema,
eval_num_batches=FLAGS.eval_num_batches,
test_num_batches=FLAGS.test_num_batches,
eval_train_num_batches=FLAGS.eval_train_num_batches,
Expand Down
54 changes: 54 additions & 0 deletions init2winit/optimizer_lib/kitchen_sink/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,59 @@ def polyak_hb(
)


def compute_params_ema_for_eval(
decay: float, warmup: bool = False
) -> optax.GradientTransformation:
"""Applies exponential moving average on weights.
Note, this implementation averages the weight before optimization because
trainable and non-trainable variables are handled separately. In such case
the updates on non-trainable variables like bn stats are not available in
updates.
This differs from optax.ema which applies ema on gradients so it changes
training process.
ema = ema * decay + new_weight * (1.0 - decay)
Args:
decay: A float number represents the weight on the moving average.
warmup: bool controlling if we ignore initial training steps for EMA.
Returns:
A GradientTransformation applying ema.
"""

def init_fn(params):
return optax.EmaState(
count=jnp.array(0, dtype=jnp.int32), ema=jax.tree_map(jnp.copy, params))

def update_fn(updates, state, params):
if params is None:
raise ValueError('Params required for the EMA')

if warmup:
# https://github.com/tensorflow/tensorflow/blob/v2.9.1/tensorflow/python/training/moving_averages.py#L469
ema_decay = jnp.minimum(decay, (1. + state.count) / (10. + state.count))
else:
ema_decay = decay

def update_func(old_v, new_v):
if old_v.dtype == jnp.bool_ or jnp.issubdtype(old_v, jnp.integer):
# If it is integer, we directly return the new variable
# This is mainly supported for non_trainable
return new_v
else:
return old_v - (1.0 - ema_decay) * (old_v - new_v)

new_ema = jax.tree_map(update_func, state.ema, params)
count_inc = state.count + jnp.array(1, jnp.int32)

return updates, optax.EmaState(count=count_inc, ema=new_ema)

return optax.GradientTransformation(init_fn, update_fn)


def first_moment_ema(
decay: float = 0.9,
debias: bool = False,
Expand Down Expand Up @@ -1706,6 +1759,7 @@ def update_fn(updates, state, params):
'ema_nesterov': ema_nesterov,
'polyak_hb': polyak_hb,
'first_moment_ema': first_moment_ema,
'compute_params_ema_for_eval': compute_params_ema_for_eval,
'normalized_first_moment_ema': normalized_first_moment_ema,
'nesterovpp': nesterovpp,
}
Expand Down
27 changes: 24 additions & 3 deletions init2winit/trainer_lib/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from init2winit.training_metrics_grabber import make_training_metrics
import jax
import numpy as np
import optax
import orbax.checkpoint as orbax_checkpoint


Expand All @@ -47,6 +48,7 @@ def __init__(
hps,
rng,
eval_batch_size,
eval_use_ema,
eval_num_batches,
test_num_batches,
eval_train_num_batches,
Expand Down Expand Up @@ -85,6 +87,7 @@ def __init__(
rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
shuffling.
eval_batch_size: the evaluation batch size. If None, use hps.batch_size.
eval_use_ema: if True evals will use ema of params.
eval_num_batches: (int) The number of batches used for evaluating on
validation sets. Set to None to evaluate on the whole eval set.
test_num_batches: (int) The number of batches used for evaluating on test
Expand Down Expand Up @@ -146,8 +149,10 @@ def __init__(
self._hps = hps
self._rng = rng
eval_batch_size = (
self._hps.batch_size if eval_batch_size is None else eval_batch_size)
self._hps.batch_size if eval_batch_size is None else eval_batch_size
)
self._eval_batch_size = eval_batch_size
self._eval_use_ema = eval_use_ema
self._eval_num_batches = eval_num_batches
self._test_num_batches = test_num_batches
self._eval_train_num_batches = eval_train_num_batches
Expand Down Expand Up @@ -534,12 +539,28 @@ def _eval(
Returns:
A Dict[str, Any] eval report, originally created in
trainer_utils.eval_metrics.
"""
time_since_last_eval = time.time() - self._time_at_prev_eval_end
self._batch_stats = trainer_utils.maybe_sync_batchnorm_stats(
self._batch_stats)
self._batch_stats
)

if self._eval_use_ema:
if isinstance(
self._optimizer_state.base_state.inner_state[0][0], optax.EmaState
):
eval_params = self._optimizer_state.base_state.inner_state[0][0].ema
else:
raise ValueError(
'EMA computation should be the very first transformation in defined'
' kitchensink optimizer.'
)
else:
eval_params = self._params

report, eval_time = trainer_utils.eval_metrics(
self._params,
eval_params,
self._batch_stats,
self._dataset,
self._eval_num_batches,
Expand Down

0 comments on commit 751a51d

Please sign in to comment.