From 41b3b9c7a9e9bb80d9f7910316e89a126f3a165a Mon Sep 17 00:00:00 2001
From: wglao
Date: Fri, 20 Jan 2023 19:24:28 -0600
Subject: [PATCH 1/2] added eve optimizer (for adam)

---
 docs/api.rst                 |  10 +++
 optax/__init__.py            |   8 ++-
 optax/_src/alias.py          | 124 +++++++++++++++++++++++++++++++++++
 optax/_src/alias_test.py     |   8 +++
 optax/_src/transform.py      |  71 ++++++++++++++++++++
 optax/_src/transform_test.py |   1 +
 6 files changed, 221 insertions(+), 1 deletion(-)

diff --git a/docs/api.rst b/docs/api.rst
index a3c09bb87..a0821bc67 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -13,6 +13,7 @@ Common Optimizers
     adamax
     adamaxw
     amsgrad
+    eve
     fromage
     lamb
     lars
@@ -67,6 +68,11 @@ AMSGrad
 .. autofunction:: amsgrad
 
+Eve
+~~~
+
+.. autofunction:: eve
+
 Fromage
 ~~~~~~~
 
 .. autofunction:: fromage
@@ -289,6 +295,7 @@ Optax Transforms and States
 .. autofunction:: scale_by_adamax
 .. autofunction:: scale_by_amsgrad
 .. autofunction:: scale_by_belief
+.. autofunction:: scale_by_eve
 .. autofunction:: scale_by_factored_rms
 .. autofunction:: scale_by_novograd
 .. autofunction:: scale_by_param_block_norm
@@ -310,6 +317,9 @@ Optax Transforms and States
 .. autoclass:: ScaleByNovogradState
    :members:
 
+.. autoclass:: ScaleByEveState
+   :members:
+
 .. autoclass:: ScaleByRmsState
    :members:

diff --git a/optax/__init__.py b/optax/__init__.py
index 278255dc5..a76aede16 100644
--- a/optax/__init__.py
+++ b/optax/__init__.py
@@ -24,6 +24,7 @@
 from optax._src.alias import adamw
 from optax._src.alias import amsgrad
 from optax._src.alias import dpsgd
+from optax._src.alias import eve
 from optax._src.alias import fromage
 from optax._src.alias import lamb
 from optax._src.alias import lars
@@ -130,6 +131,7 @@
 from optax._src.transform import scale_by_adamax
 from optax._src.transform import scale_by_amsgrad
 from optax._src.transform import scale_by_belief
+from optax._src.transform import scale_by_eve
 from optax._src.transform import scale_by_novograd
 from optax._src.transform import scale_by_optimistic_gradient
 from optax._src.transform import scale_by_param_block_norm
@@ -145,6 +147,7 @@
 from optax._src.transform import ScaleByAdamState
 from optax._src.transform import ScaleByAmsgradState
 from optax._src.transform import ScaleByBeliefState
+from optax._src.transform import ScaleByEveState
 from optax._src.transform import ScaleByNovogradState
 from optax._src.transform import ScaleByRmsState
 from optax._src.transform import ScaleByRssState
@@ -177,7 +180,7 @@
 from optax._src.wrappers import skip_large_updates
 from optax._src.wrappers import skip_not_finite
 
-__version__ = "0.1.5.dev"
+__version__ = "0.1.5.dev0"
 
 __all__ = (
     "adabelief",
@@ -223,6 +226,7 @@
     "ema",
     "EmaState",
     "EmptyState",
+    "eve",
     "exponential_decay",
     "FactoredState",
     "fisher_diag",
@@ -284,6 +288,7 @@
     "scale_by_adamax",
     "scale_by_amsgrad",
     "scale_by_belief",
+    "scale_by_eve",
     "scale_by_factored_rms",
     "scale_by_novograd",
     "scale_by_param_block_norm",
@@ -301,6 +306,7 @@
     "ScaleByAdamState",
     "ScaleByAmsgradState",
     "ScaleByBeliefState",
+    "ScaleByEveState",
     "ScaleByNovogradState",
     "ScaleByRmsState",
     "ScaleByRssState",

diff --git a/optax/_src/alias.py b/optax/_src/alias.py
index c6ae6b602..682c9af1a 100644
--- a/optax/_src/alias.py
+++ b/optax/_src/alias.py
@@ -23,6 +23,7 @@
 from optax._src import combine
 from optax._src import factorized
 from optax._src import privacy
+from optax._src import schedule
 from optax._src import transform
 from optax._src import wrappers
 
@@ -339,6 +340,129 @@ def amsgrad(
       _scale_by_learning_rate(learning_rate),
   )
 
+
+def _eve(
+    a1: float = 1e-3,
+    b1: float = 0.9,
+    b2: float = 0.999,
+    b3: float = 0.999,
+    c: float = 10.,
+    eps: float = 1e-8,
+    f: float = 1.,
+    f_star: float = 0.,
+    mu_dtype: Optional[Any] = None,
+) -> base.GradientTransformation:
+  """The Eve optimizer, without hyperparameter injection (see `eve`).
+
+  Eve is an SGD variant with adaptive global and local learning rates. The
+  local learning rate for each weight is computed from estimates of the
+  first- and second-order moments of the gradients (using exponential moving
+  averages), as in Adam. These are then scaled by the global learning rate
+  `a1`, which is adaptively modified by a measure of sub-optimality `d`:
+  the global rate is increased when far from the optimum and decreased as the
+  loss approaches optimality. The sub-optimality is likewise tracked with an
+  exponential moving average, like the first and second moments.
+
+  References:
+    Hayashi et al, 2018: https://arxiv.org/abs/1611.01505
+
+  Args:
+    a1: the initial global scaling factor (i.e. the base learning rate).
+    b1: the exponential decay rate to track the first moment of past gradients.
+    b2: the exponential decay rate to track the second moment of past gradients.
+    b3: the exponential decay rate to track the sub-optimality.
+    c: the clipping limit to prevent extreme global learning rate changes.
+    eps: a small constant applied to the denominator outside of the square root
+      (as in the Adam paper) to avoid dividing by zero when rescaling.
+    f: the current loss value (must be injected before `update` is called).
+    f_star: an estimate of the global minimum of the loss.
+    mu_dtype: optional `dtype` to be used for the first order accumulator; if
+      `None` then the `dtype` is inferred from `params` and `updates`.
+
+  Returns:
+    The corresponding `GradientTransformation`.
+
+  Note:
+    Eve requires an additional parameter: the loss for the current iteration::
+
+      f := f_t
+
+    `ScaleByEveState` also holds the loss from the previous iteration::
+
+      state.f_prev := f_{t-1}
+
+    Since it is up to the user to inject the current loss before calling the
+    update function, the `eve` alias wraps `_eve` in `inject_hyperparams` so
+    that `f` can be updated through `opt_state.hyperparams`.
+  """
+  return combine.chain(
+      transform.scale_by_eve(
+          b1=b1, b2=b2, b3=b3, c=c, eps=eps, f=f, f_star=f_star,
+          mu_dtype=mu_dtype),
+      _scale_by_learning_rate(a1),
+  )
+
+
+def eve(
+    a1: float = 1e-3,
+    b1: float = 0.9,
+    b2: float = 0.999,
+    b3: float = 0.999,
+    c: float = 10.,
+    eps: float = 1e-8,
+    f: float = 1.,
+    f_star: float = 0.,
+    mu_dtype: Optional[Any] = None,
+) -> base.GradientTransformation:
+  """The Eve optimizer, made injectable via `inject_hyperparams`.
+
+  Eve requires an additional parameter: the loss for the current iteration::
+
+    f := f_t
+
+  `ScaleByEveState` also holds the loss from the previous iteration::
+
+    state.f_prev := f_{t-1}
+
+  Since it is up to the user to inject the current loss before calling the
+  update function, this alias wraps `_eve` in `inject_hyperparams` so that
+  `f` can be updated through `opt_state.hyperparams`.
+
+  Args:
+    a1: the initial global scaling factor (i.e. the base learning rate).
+    b1: the exponential decay rate to track the first moment of past gradients.
+    b2: the exponential decay rate to track the second moment of past gradients.
+    b3: the exponential decay rate to track the sub-optimality.
+    c: the clipping limit to prevent extreme global learning rate changes.
+    eps: a small constant applied to the denominator outside of the square root
+      (as in the Adam paper) to avoid dividing by zero when rescaling.
+    f: the current loss value (must be injected before `update` is called).
+    f_star: an estimate of the global minimum of the loss.
+    mu_dtype: optional `dtype` to be used for the first order accumulator; if
+      `None` then the `dtype` is inferred from `params` and `updates`.
+
+  Returns:
+    The corresponding `GradientTransformation`, wrapped in `inject_hyperparams`.
+
+  Inject the current loss as follows. Initialize::
+
+    optimizer = optax.eve()
+    opt_state = optimizer.init(params)
+
+  Train::
+
+    while training:
+      loss, grads = jax.value_and_grad(loss_fn)(params, data)
+      opt_state.hyperparams['f'] = loss  # <-- inject the current loss here
+      updates, opt_state = optimizer.update(grads, opt_state)
+      params = optax.apply_updates(params, updates)
+  """
+  return schedule.inject_hyperparams(_eve)(
+      a1=a1, b1=b1, b2=b2, b3=b3, c=c, eps=eps,
+      f=f, f_star=f_star, mu_dtype=mu_dtype
+  )
+
 
 def fromage(
     learning_rate: float,
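Note: the docstrings above describe the global-rate adaptation in words; numerically it is the following recursion. This is a standalone sketch in plain `jax.numpy`, not part of the patch; `eve_suboptimality` and its arguments `f`, `f_prev`, `f_star`, `d_prev` and `t` are hypothetical stand-ins for the values held in `ScaleByEveState`::

    import jax.numpy as jnp

    def eve_suboptimality(f, f_prev, f_star, d_prev, t, b3=0.999, c=10.):
      # Relative loss change, measured against the distance to the estimated minimum.
      d_new = jnp.abs(f - f_prev) / (jnp.minimum(f, f_prev) - f_star)
      # Clip to [1/c, c] so a single noisy loss value cannot blow up the global rate.
      d_tilde = jnp.clip(d_new, 1. / c, c)
      # Exponential moving average; the very first step always uses d_1 = 1.
      return jnp.where(t > 1, b3 * d_prev + (1. - b3) * d_tilde, 1.)

Each update is then divided by d_t, so the effective step size is a1 / d_t, bounded between a1 / c and c * a1.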
diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py
index be1a68b30..f8986e5a8 100644
--- a/optax/_src/alias_test.py
+++ b/optax/_src/alias_test.py
@@ -35,6 +35,7 @@
     dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1e-1)),
     dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1e-1)),
     dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1e-1)),
+    dict(opt_name='eve', opt_kwargs=dict(f=10)),
     dict(opt_name='lars', opt_kwargs=dict(learning_rate=1.0)),
     dict(opt_name='lamb', opt_kwargs=dict(learning_rate=1e-3)),
     dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1e-3, eta=1e-4)),
@@ -116,6 +117,9 @@ def step(params, state):
       updates = get_updates(params)
       if opt_name == 'dpsgd':
         updates = updates[None]
+      elif opt_name == 'eve':
+        f = jnp.mean(jnp.square(params - final_params))
+        state.hyperparams['f'] = f
       # Complex gradients need to be conjugated before being added to parameters
       # https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
       updates = jax.tree_util.tree_map(lambda x: x.conj(), updates)
@@ -144,6 +148,10 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
       # https://github.com/deepmind/optax/issues/412.
       opt_inject = schedule.inject_hyperparams(
          opt_factory, static_args=('min_dim_size_to_factor',))(**opt_kwargs)
+    elif opt_name == 'eve':
+      # Eve is injectable by default. Reassign opt to the uninjectable _eve alias.
+      opt = alias._eve(**opt_kwargs)
+      opt_inject = opt_factory(**opt_kwargs)
     else:
       opt_inject = schedule.inject_hyperparams(opt_factory)(**opt_kwargs)

diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index 2bbc75e92..1d6f7acff 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -449,6 +449,77 @@ def update_fn(updates, state, params=None):
   return base.GradientTransformation(init_fn, update_fn)
 
 
+class ScaleByEveState(NamedTuple):
+  """State for the Eve algorithm."""
+  count: chex.Array  # shape=(), dtype=jnp.int32.
+  mu: base.Updates
+  nu: base.Updates
+  d: float
+  f_prev: float
+
+
+def scale_by_eve(
+    b1: float = 0.9,
+    b2: float = 0.999,
+    b3: float = 0.999,
+    c: float = 10.,
+    eps: float = 1e-8,
+    f: float = 1.,
+    f_star: float = 0.,
+    mu_dtype: Optional[Any] = None,
+) -> base.GradientTransformation:
+  """Rescale updates according to the Eve algorithm.
+
+  References:
+    [Hayashi et al, 2018](https://arxiv.org/abs/1611.01505)
+
+  Args:
+    b1: the exponential decay rate to track the first moment of past gradients.
+    b2: the exponential decay rate to track the second moment of past gradients.
+    b3: the exponential decay rate to track the sub-optimality.
+    c: the clipping limit to prevent extreme global learning rate changes.
+    eps: a small constant applied to the denominator outside of the square root
+      (as in the Adam paper) to avoid dividing by zero when rescaling.
+    f: the current loss value (must be injected before `update` is called).
+    f_star: an estimate of the global minimum of the loss.
+    mu_dtype: optional `dtype` to be used for the first order accumulator; if
+      `None` then the `dtype` is inferred from `params` and `updates`.
+
+  Returns:
+    An (init_fn, update_fn) tuple.
+  """
+  mu_dtype = utils.canonicalize_dtype(mu_dtype)
+
+  def init_fn(params):
+    mu = jax.tree_util.tree_map(  # First moment
+        lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
+    nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Second moment
+    # The initial f_prev is never used: on the first step d is fixed to 1.
+    return ScaleByEveState(
+        count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, d=1., f_prev=10.)
+
+  def update_fn(updates: base.Updates, state: ScaleByEveState, params=None):
+    del params
+    mu = update_moment(updates, state.mu, b1, 1)
+    nu = update_moment_per_elem_norm(updates, state.nu, b2, 2)
+    count_inc = numerics.safe_int32_increment(state.count)
+    # Adam-style bias correction of the first and second moment estimates.
+    mu_hat = jax.tree_util.tree_map(lambda m: m / (1 - b1**count_inc), mu)
+    nu_hat = jax.tree_util.tree_map(lambda v: v / (1 - b2**count_inc), nu)
+    # Smoothed, clipped sub-optimality factor d, used to rescale the updates.
+    d_new = (jnp.abs(f - state.f_prev) /
+             (jnp.min(jnp.array([f, state.f_prev])) - f_star))
+    d_tilde = jnp.clip(d_new, 1 / c, c)
+    d = jnp.where(count_inc > 1, b3 * state.d + (1 - b3) * d_tilde, 1.)
+    updates = jax.tree_util.tree_map(
+        lambda m, v: m / (jnp.sqrt(v) + eps) / d, mu_hat, nu_hat)
+    mu = utils.cast_tree(mu, mu_dtype)
+    return updates, ScaleByEveState(
+        count=count_inc, mu=mu, nu=nu, d=d, f_prev=f)
+
+  return base.GradientTransformation(init_fn, update_fn)
+
+
 ScaleState = base.EmptyState

diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py
index 2c4ea9482..628db3ac8 100644
--- a/optax/_src/transform_test.py
+++ b/optax/_src/transform_test.py
@@ -44,6 +44,7 @@ def setUp(self):
   @parameterized.named_parameters([
       ('adam', transform.scale_by_adam),
       ('adamax', transform.scale_by_adamax),
+      ('eve', transform.scale_by_eve),
       ('rmsprop', transform.scale_by_rms),
       ('stddev', transform.scale_by_stddev),
       ('trust_ratio', transform.scale_by_trust_ratio),
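Note: a quick sanity check of `scale_by_eve` from the patch above. This is a sketch that assumes the patch is applied; the parameter tree and gradient values are arbitrary::

    import jax.numpy as jnp
    import optax  # with the patch above applied

    tx = optax.scale_by_eve(f=0.5, f_star=0.)
    params = {'w': jnp.array([1.0, -2.0])}
    state = tx.init(params)
    grads = {'w': jnp.array([0.1, -0.3])}
    updates, state = tx.update(grads, state)
    # On the first step d_1 = 1 and the bias-corrected moments reduce to the raw
    # gradient, so each entry is roughly g / (|g| + eps), i.e. close to +/-1.
    print(updates['w'])

Chaining this with `optax.scale(-learning_rate)`, as `_eve` does, turns that direction into an actual descent step.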
From 308c6614f91a1bf17b29f224a8165408003db2a8 Mon Sep 17 00:00:00 2001
From: Amos You
Date: Thu, 15 Feb 2024 12:16:50 -0800
Subject: [PATCH 2/2] implementation of eve optimizer as wrapper

---
 optax/_src/wrappers.py | 61 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py
index 9d0d50bc3..745cad406 100644
--- a/optax/_src/wrappers.py
+++ b/optax/_src/wrappers.py
@@ -24,7 +24,8 @@
 from optax._src import base
 from optax._src import numerics
 from optax.tree_utils import _state_utils
+import optax.tree_utils as tu
 
 Array = jnp.ndarray
 
@@ -610,3 +611,62 @@ def reject_update(_):
           numerics.safe_int32_increment(state.step))
 
   return base.GradientTransformationExtraArgs(init_fn, update_fn)
+
+
+class EveState(NamedTuple):
+  """State for the Eve wrapper.
+
+  Attributes:
+    inner_state: the state of the wrapped optimizer.
+    step: the counter for the current step (t).
+    f_prev: the previous loss value (f_{t-1}).
+    d_tilde_prev: the previous smoothed sub-optimality factor (d_tilde_{t-1}).
+  """
+  inner_state: base.OptState
+  step: Union[jax.Array, int]
+  f_prev: Union[jax.Array, float]
+  d_tilde_prev: Union[jax.Array, float]
+
+
+def eve(
+    inner: base.GradientTransformation,
+    b3: float,
+    c: float,
+    f: Union[jax.Array, float],
+    f_star: Union[jax.Array, float]
+) -> base.GradientTransformation:
+  """Eve optimizer, implemented as a wrapper around an inner transformation.
+
+  Args:
+    inner: the inner transformation whose updates are rescaled.
+    b3: the exponential decay rate to track the sub-optimality.
+    c: the clipping limit to prevent extreme global learning rate changes.
+    f: the current loss value.
+    f_star: the estimated global minimum of the loss.
+
+  Returns:
+    New ``GradientTransformation``.
+  """
+
+  def init_fn(params):
+    return EveState(
+        inner_state=inner.init(params),
+        step=0,
+        f_prev=f,
+        d_tilde_prev=1.)
+
+  def update_fn(updates, state, params=None):
+    step = numerics.safe_int32_increment(state.step)
+    # Smoothed, clipped sub-optimality factor, as in scale_by_eve.
+    d = (jnp.abs(f - state.f_prev) /
+         (jnp.min(jnp.array([f, state.f_prev])) - f_star))
+    d_hat = jnp.clip(d, 1 / c, c)
+    d_tilde = jnp.where(
+        step > 1, b3 * state.d_tilde_prev + (1 - b3) * d_hat, 1.)
+
+    new_inner_updates, new_inner_state = inner.update(
+        updates, state.inner_state, params)
+    new_updates = tu.tree_scalar_mul(1 / d_tilde, new_inner_updates)
+    return new_updates, EveState(
+        inner_state=new_inner_state,
+        step=step,
+        f_prev=f,
+        d_tilde_prev=d_tilde)
+
+  return base.GradientTransformation(init_fn, update_fn)
\ No newline at end of file
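Note: because `f` is a constructor argument of this wrapper as well, the current loss still has to be injected if it is to change between steps. Below is a minimal training-loop sketch under the assumption that both patches are applied; since this patch does not export the wrapper from `optax/__init__.py`, it is imported from `optax._src.wrappers` here, `inner` is passed to `inject_hyperparams` as a static argument, and `loss_fn`, `params` and `data` are toy placeholders::

    import jax
    import jax.numpy as jnp
    import optax
    from optax._src import wrappers  # the module patched above

    def loss_fn(params, batch):  # toy quadratic loss
      return jnp.mean((params['w'] - batch) ** 2)

    params = {'w': jnp.zeros(3)}
    data = [jnp.full((3,), float(i)) for i in range(5)]

    # Keep the non-numeric inner transformation static so that only b3, c, f
    # and f_star end up in opt_state.hyperparams.
    opt = optax.inject_hyperparams(wrappers.eve, static_args=('inner',))(
        inner=optax.adam(1e-3), b3=0.999, c=10., f=1., f_star=0.)
    opt_state = opt.init(params)

    for batch in data:
      loss, grads = jax.value_and_grad(loss_fn)(params, batch)
      opt_state.hyperparams['f'] = loss   # inject the current loss
      updates, opt_state = opt.update(grads, opt_state)
      params = optax.apply_updates(params, updates)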