Merge pull request #1104 from mathDR:main
PiperOrigin-RevId: 720207907
OptaxDev committed Jan 27, 2025
2 parents 3e692cd + 905553f commit a357cf6
Showing 8 changed files with 733 additions and 6 deletions.
7 changes: 7 additions & 0 deletions docs/api/contrib.rst
@@ -8,6 +8,7 @@ Experimental features and algorithms that don't meet the

.. autosummary::
acprop
ademamix
cocob
COCOBState
dadapt_adamw
@@ -41,6 +42,12 @@ Experimental features and algorithms that don't meet the
split_real_and_imaginary
SplitRealAndImaginaryState

AdEMAMix
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: ademamix
.. autofunction:: scale_by_ademamix
.. autoclass:: ScaleByAdemamixState

Asynchronous-centering-Prop
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: acprop
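The contrib.rst entries above expose three new public names. As a brief editorial sketch (not part of this commit), the low-level transform can be chained by hand much like other Optax transforms; the hyperparameter values below are illustrative only:

import jax
import jax.numpy as jnp
import optax

# Chain the raw AdEMAMix scaling rule with a learning-rate scale; the
# `ademamix` alias documented above packages this (plus weight decay).
tx = optax.chain(
    optax.contrib.scale_by_ademamix(b1=0.9, b2=0.999, b3=0.9999, alpha=5.0),
    optax.scale_by_learning_rate(1e-3),
)

params = {'w': jnp.ones((3,))}
opt_state = tx.init(params)
grads = jax.grad(lambda p: jnp.sum(p['w'] ** 2))(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

# opt_state[0] is the ScaleByAdemamixState documented above
# (fields: count, count_m2, m1, m2, nu).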
16 changes: 16 additions & 0 deletions docs/gallery.rst
@@ -284,6 +284,22 @@ Examples that make use of the :doc:`api/contrib` module.
<div class="sphx-glr-thumbnail-title">Sharpness-Aware Minimization (SAM).</div>
</div>

.. raw:: html

<div class="sphx-glr-thumbcontainer" tooltip="AdEMAMix.">

.. only:: html

.. image:: /images/examples/contrib/ademamix_rosenbrock.png
:alt:

:doc:`_collections/examples/contrib/ademamix_rosenbrock`

.. raw:: html

<div class="sphx-glr-thumbnail-title">AdEMAMix.</div>
</div>


.. raw:: html

Binary file not shown (the gallery thumbnail image referenced above).
431 changes: 431 additions & 0 deletions examples/contrib/rosenbrock_ademamix.ipynb

Large diffs are not rendered by default.
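
The notebook diff above is not rendered, so here is a short editorial sketch, not the notebook's actual contents, of what minimizing the 2-D Rosenbrock function with the new optimizer can look like (learning rate and step count are arbitrary):

import jax
import jax.numpy as jnp
import optax

def rosenbrock(p):
  # Classic 2-D Rosenbrock function, minimized at (1, 1).
  return (1.0 - p[0]) ** 2 + 100.0 * (p[1] - p[0] ** 2) ** 2

params = jnp.array([-1.0, 1.0])
solver = optax.contrib.ademamix(learning_rate=1e-2)
opt_state = solver.init(params)

for _ in range(1000):
  grads = jax.grad(rosenbrock)(params)
  updates, opt_state = solver.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)

print(rosenbrock(params))  # the objective should have decreased toward 0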

12 changes: 6 additions & 6 deletions optax/_src/alias.py
@@ -2481,12 +2481,12 @@ def lbfgs(
... grad, opt_state, params, value=value, grad=grad, value_fn=f
... )
... params = optax.apply_updates(params, updates)
... print('Objective function: ', f(params))
Objective function: 7.5166864
Objective function: 7.460699e-14
Objective function: 2.6505726e-28
Objective function: 0.0
Objective function: 0.0
... print('Objective function: {:.2E}'.format(f(params)))
Objective function: 7.52E+00
Objective function: 7.46E-14
Objective function: 2.65E-28
Objective function: 0.00E+00
Objective function: 0.00E+00
References:
Algorithms 7.4, 7.5 (page 199) of Nocedal et al, `Numerical Optimization
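The change above only reformats the doctest output; for context, '{:.2E}' prints every value at fixed precision and width, which keeps the expected output uniform across magnitudes. A standalone illustration (not from the commit):

for value in (7.5166864, 7.460699e-14, 2.6505726e-28, 0.0):
  print('Objective function: {:.2E}'.format(value))
# Objective function: 7.52E+00
# Objective function: 7.46E-14
# Objective function: 2.65E-28
# Objective function: 0.00E+00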
3 changes: 3 additions & 0 deletions optax/contrib/__init__.py
@@ -18,6 +18,9 @@

from optax.contrib._acprop import acprop
from optax.contrib._acprop import scale_by_acprop
from optax.contrib._ademamix import ademamix
from optax.contrib._ademamix import scale_by_ademamix
from optax.contrib._ademamix import ScaleByAdemamixState
from optax.contrib._cocob import cocob
from optax.contrib._cocob import COCOBState
from optax.contrib._cocob import scale_by_cocob
269 changes: 269 additions & 0 deletions optax/contrib/_ademamix.py
@@ -0,0 +1,269 @@
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AdEMAMix.
Implementation of
"THE ADEMAMIX OPTIMIZER: BETTER, FASTER, OLDER"
(https://arxiv.org/pdf/2409.03137) by Matteo Pagliardini,
Pierre Ablin and David Grangier.
"""

from typing import Any, Callable, NamedTuple, Optional, Union
import chex
import jax.numpy as jnp
import jax.tree_util as jtu
from optax._src import base
from optax._src import combine
from optax._src import numerics
from optax._src import transform
from optax._src import utils
import optax.tree_utils as otu


class ScaleByAdemamixState(NamedTuple):
"""State for the Ademamix algorithm.
Attributes:
count: iteration of the algorithm used to update the fast EMA and second
moment.
count_m2: iteration of the algorithm used to update the slow EMA and alpha.
m1: fast EMA of the first moment
m2: slow EMA of the first moment
nu: estimate of the second moment
"""

count: chex.Array # shape=(), dtype=jnp.int32.
count_m2: chex.Array # shape=(), dtype=jnp.int32.
m1: base.Updates
m2: base.Updates
nu: base.Updates


def scale_by_ademamix(
b1: float = 0.9,
b2: float = 0.999,
b3: base.ScalarOrSchedule = 0.9999,
alpha: base.ScalarOrSchedule = 6.0,
eps: float = 1e-8,
eps_root: float = 0.0,
mu_dtype: Optional[chex.ArrayDType] = None
) -> base.GradientTransformation:
"""Scale updates according to the Ademamix algorithm.
See :func:`optax.contrib.ademamix.` for a full description of the algorithm.
References:
Pagliardini et al, `The AdEMAMix Optimizer: Better, Faster, Older
<https://arxiv.org/abs/2409.03137>`_, 2024
Args:
b1: Exponential decay rate to track the fast EMA.
b2: Exponential decay rate to track the second moment of past gradients.
b3: Exponential decay rate to track the slow EMA.
alpha: Mixing coefficient in the linear combination fo the fast and slow
EMAs.
eps: A small constant applied to denominator outside of the square root (as
in the Adam paper) to avoid dividing by zero when rescaling.
eps_root: A small constant applied to denominator inside the square root (as
in RMSProp), to avoid dividing by zero when rescaling. This is needed for
instance when computing (meta-)gradients through Adam.
mu_dtype: Optional `dtype` to be used for the first order accumulator; if
`None` then the `dtype` is inferred from `params` and `updates`.
Returns:
The corresponding `GradientTransformation`.
"""

mu_dtype = utils.canonicalize_dtype(mu_dtype)

def init_fn(params) -> ScaleByAdemamixState:
m1 = otu.tree_zeros_like(params, dtype=mu_dtype) # fast EMA
m2 = otu.tree_zeros_like(params, dtype=mu_dtype) # slow EMA
nu = otu.tree_zeros_like(params, dtype=mu_dtype) # second moment estimate
return ScaleByAdemamixState(
count=jnp.zeros([], jnp.int32),
count_m2=jnp.zeros([], jnp.int32),
m1=m1,
m2=m2,
nu=nu,
)

def update_fn(updates, state, params=None):
del params
c_b3 = b3(state.count_m2) if callable(b3) else b3
c_alpha = alpha(state.count_m2) if callable(alpha) else alpha
m1 = otu.tree_update_moment(
updates, state.m1, b1, order=1
) # m1 = b1 * m1 + (1-b1) * updates
m2 = otu.tree_update_moment(updates, state.m2, c_b3, order=1)
nu = otu.tree_update_moment_per_elem_norm(updates, state.nu, b2, order=2)
count_inc = numerics.safe_int32_increment(state.count)
count_m2_inc = numerics.safe_int32_increment(state.count_m2)
m1_hat = otu.tree_bias_correction(m1, b1, count_inc)
# NOTE: AdEMAMix does not perform bias correction on the slow EMA (m2),
# letting that momentum buffer fill itself up slowly.
nu_hat = otu.tree_bias_correction(nu, b2, count_inc)
updates = jtu.tree_map(
lambda m1_, m2_, v_: (
(m1_ + c_alpha * m2_) / (jnp.sqrt(v_ + eps_root) + eps)
),
m1_hat,
m2,
nu_hat,
)
return updates, ScaleByAdemamixState(
count=count_inc, count_m2=count_m2_inc, m1=m1, m2=m2, nu=nu
)

return base.GradientTransformation(init_fn, update_fn)


def ademamix(
learning_rate: base.ScalarOrSchedule,
b1: float = 0.9,
b2: float = 0.999,
b3: base.ScalarOrSchedule = 0.9999,
alpha: base.ScalarOrSchedule = 5.0,
eps: float = 1e-8,
eps_root: float = 0.0,
mu_dtype: Optional[Any] = None,
weight_decay: float = 0.0,
mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
r"""AdEMAMix.
AdEMAMix (Adaptive EMA Mixture) is AdamW with a mixture of two momentum
terms to better take advantage of historical gradients.
Both SGD with momemtum (SGD+M) and Adam incorporate momentum using
Exponential Moving Averages (EMAs) of past gradients
Let :math:`\eta` represent the learning rate and :math:`\beta_1, \beta_2`,
:math:`\beta_3, \alpha, \varepsilon, \bar{\varepsilon}`, represent the
arguments ``b1``, ``b2``, ``b3``, ``alpha``, ``eps`` and ``eps_root``
respectively. Let :math:`\lambda` be the weight decay and :math:`\theta_t`
the parameter vector at time :math:`t`.
The ``init`` function of this optimizer initializes an internal state
:math:`S_0 := (m^{(1)}_0, m^{(2)}_0, \nu_0) = (0, 0, 0)`, representing initial
estimates for the fast and slow EMAs of the first moment along with the second
moment estimate. In practice, these values are stored as pytrees containing
all zeros, with the same shape as the model updates. At step :math:`t`,
the ``update`` function of this optimizer takes as arguments the incoming
gradients :math:`g^{(t)}`, the optimizer state :math:`S^{(t)}` and the
parameters :math:`\theta^{(t)}`. It then computes the new parameters
:math:`\theta^{(t+1)}` and the new state :math:`S^{(t+1)}`. Thus, for
:math:`t > 0`, we have,

.. math::
\begin{align*}
m_1^{(t)} &\leftarrow \beta_1 \cdot m_1^{(t-1)} + (1-\beta_1)
\cdot g^{(t)} \\
m_2^{(t)} &\leftarrow \beta_3 \cdot m_2^{(t-1)} + (1-\beta_3) \cdot
g^{(t)} \\
\nu^{(t)} &\leftarrow \beta_2 \cdot \nu^{(t-1)} + (1-\beta_2) \cdot
{g^{(t)}}^2 \\
\hat{m_1}^{(t)} &\leftarrow m_1^{(t)} / {(1-\beta_1^{(t)})} \\
\hat{\nu}^{(t)} &\leftarrow \nu^{(t)} / {(1-\beta_2^{(t)})} \\
\theta^{(t)} &\leftarrow \theta^{(t-1)} - \eta \cdot \left(
\frac{(\hat{m_1}^{(t)} + \alpha m_2^{(t)})}{\left(\sqrt{\hat{\nu}^{(t)}
+ \bar{\varepsilon}} + \varepsilon\right)} + \lambda \theta^{(t-1)}
\right).\\
S^{(t)} &\leftarrow (m_1^{(t)}, m_2^{(t)}, v^{(t)}).
\end{align*}

.. note::
AdEMAMix relies on leveraging very old gradients, so the method is best
suited to settings with a large number of iterations. The paper reports on
this effect in Appendix C.1.5, showing how smaller values of ``b3`` (e.g.
``b3 = 0.999``) can be better in low-iteration scenarios. Moreover,
retaining gradient information over many thousands of steps can pose a
problem in domains requiring fast adaptation to a sudden distribution
shift, or general cases in which the distribution is non-stationary.

Examples:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(jnp.square(x)) # simple quadratic function
>>> solver = optax.contrib.ademamix(learning_rate=0.01)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
... grad = jax.grad(f)(params)
... updates, opt_state = solver.update(grad, opt_state, params)
... params = optax.apply_updates(params, updates)
... print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.36E+01
Objective function: 1.35E+01
Objective function: 1.34E+01

References:
Pagliardini et al, `The AdEMAMix Optimizer: Better, Faster, Older
<https://arxiv.org/abs/2409.03137>`_, 2024

Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
b1: Exponential decay rate to track the fast EMA.
b2: Exponential decay rate to track the second moment of past gradients.
b3: Exponential decay rate to track the slow EMA.
alpha: Mixing coefficient in the linear combination of the fast and
slow EMAs.
eps: A small constant applied to denominator outside of the square root
(as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root: A small constant applied to denominator inside the square root (as
in RMSProp), to avoid dividing by zero when rescaling. This is needed for
instance when computing (meta-)gradients through Adam.
mu_dtype: Optional `dtype` to be used for the first order accumulator; if
`None` then the `dtype` is inferred from `params` and `updates`.
weight_decay: Strength of the weight decay regularization. Note that this
weight decay is multiplied with the learning rate. This is consistent
with other frameworks such as PyTorch, but different from
(Loshchilov et al, 2019) where the weight decay is only multiplied with
the "schedule multiplier", but not the base learning rate.
mask: A tree with same structure as (or a prefix of) the params PyTree,
or a Callable that returns such a pytree given the params/updates.
The leaves should be booleans, `True` for leaves/subtrees you want to
apply the weight decay to, and `False` for those you want to skip. Note
that the Adam gradient transformations are applied to all parameters.

Returns:
The corresponding `GradientTransformation`.

.. seealso::
See the related functions :func:`optax.adam`, :func:`optax.nadamw`, as well
as the example :doc:`../_collections/examples/contrib/rosenbrock_ademamix`
for a use case.
"""
return combine.chain(
scale_by_ademamix(
b1=b1,
b2=b2,
b3=b3,
alpha=alpha,
eps=eps,
eps_root=eps_root,
mu_dtype=mu_dtype,
),
transform.add_decayed_weights(weight_decay, mask),
transform.scale_by_learning_rate(learning_rate),
)
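
To complement the file above, a short editorial sketch (not part of the committed file) of using the full `ademamix` alias with the `weight_decay` and `mask` arguments described in its docstring; the mask and hyperparameters are illustrative assumptions:

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((4, 4)), 'b': jnp.zeros(4)}

# Apply weight decay to the weight matrix only, not the bias.
solver = optax.contrib.ademamix(
    learning_rate=1e-3,
    b1=0.9, b2=0.999, b3=0.9999, alpha=5.0,
    weight_decay=1e-2,
    mask={'w': True, 'b': False},
)
opt_state = solver.init(params)

def loss(p):
  return jnp.sum(p['w'] ** 2) + jnp.sum(p['b'] ** 2)

grads = jax.grad(loss)(params)
updates, opt_state = solver.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)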
1 change: 1 addition & 0 deletions optax/contrib/_common_test.py
@@ -39,6 +39,7 @@
# Testing contributions coded as GradientTransformations
_MAIN_OPTIMIZERS_UNDER_TEST = [
{'opt_name': 'acprop', 'opt_kwargs': {'learning_rate': 1e-3}},
{'opt_name': 'ademamix', 'opt_kwargs': {'learning_rate': 1e-3}},
{'opt_name': 'cocob', 'opt_kwargs': {}},
{'opt_name': 'cocob', 'opt_kwargs': {'weight_decay': 1e-2}},
{'opt_name': 'dadapt_adamw', 'opt_kwargs': {'learning_rate': 1e-1}},
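For illustration only, a hypothetical sketch of how a shared test harness might turn an entry like the one added above into an optimizer instance; this is not the actual `_common_test.py` machinery:

import optax

def instantiate(opt_name, opt_kwargs):
  # Hypothetical helper: look up the contrib optimizer by name and build it.
  return getattr(optax.contrib, opt_name)(**opt_kwargs)

opt = instantiate('ademamix', {'learning_rate': 1e-3})
assert isinstance(opt, optax.GradientTransformation)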
