Merge pull request #1180 from carlosgmartin:polyak_sgd_plus
PiperOrigin-RevId: 721010195
OptaxDev committed Jan 29, 2025
2 parents 225a707 + 15fd895 commit c51fbd5
Showing 2 changed files with 25 additions and 2 deletions.
15 changes: 14 additions & 1 deletion optax/_src/alias.py
@@ -2314,6 +2314,7 @@ def polyak_sgd(
scaling: base.ScalarOrSchedule = 1.0,
f_min: float = 0.0,
eps: float = 0.0,
variant: str = 'sps',
) -> base.GradientTransformationExtraArgs:
r"""SGD with Polyak step-size.
@@ -2331,13 +2332,18 @@ def polyak_sgd(
:math:`f^\star` is a guess of the minimum value of the function set with
``f_min``.
Setting ``variant="sps+"`` (Garrigos et al. 2023) uses only the non-negative
part of the suboptimality gap. That is, it replaces :math:`f(x) - f^\star`
with :math:`(f(x) - f^\star)_+`, where :math:`a_+ = \max\{a, 0\}`.
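Concretely (a sketch reconstructed from ``update_fn`` in ``optax/_src/transform.py`` below, not text taken from the original docstring), the step size applied with ``variant='sps+'`` is

.. math::
  \gamma_t = \min\left(\frac{(f(x_t) - f^\star)_+}{\|\nabla f(x_t)\|_2^2 + \epsilon},\ \gamma_{\max}\right),

where :math:`\gamma_{\max}` denotes ``max_learning_rate``; with ``variant='sps'`` the positive part is dropped, and the step is set to zero whenever the denominator is at machine precision.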
Args:
max_learning_rate: a maximum step size to use (defaults to 1).
scaling: A global scaling factor, either fixed or evolving along iterations
with a scheduler (defaults to 1).
f_min: a lower bound on the objective function (defaults to 0). Corresponds
to :math:`f^\star` in the formula above.
eps: a value to add in the denominator of the update (defaults to 0).
variant: either ``'sps'`` or ``'sps+'`` (defaults to ``'sps'``).
Returns:
A :class:`optax.GradientTransformationExtraArgs`, where the ``update``
@@ -2371,6 +2377,10 @@ def polyak_sgd(
Berrada et al., `Training neural networks for and by interpolation
<https://arxiv.org/pdf/1906.05661.pdf>`_, 2020
Garrigos et al., `Function value learning: Adaptive learning rates based on
the Polyak stepsize and function splitting in ERM
<https://arxiv.org/abs/2307.14528>`_, 2023
.. warning::
This method requires knowledge of an approximate value of the objective
function minimum, passed through the ``f_min`` argument.
@@ -2382,7 +2392,10 @@
return combine.chain(
sgd(learning_rate=scaling),
transform.scale_by_polyak(
max_learning_rate=max_learning_rate, f_min=f_min, eps=eps
max_learning_rate=max_learning_rate,
f_min=f_min,
eps=eps,
variant=variant,
),
)
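For illustration, a minimal usage sketch of the new argument, assuming an optax version that includes this change; the toy quadratic objective, parameter values, and loop length are assumptions made for this example, not part of the diff:

import jax
import jax.numpy as jnp
import optax

def loss_fn(params):
  return jnp.sum(params ** 2)  # toy objective whose minimum value is 0.0

params = jnp.array([1.0, -2.0])
opt = optax.polyak_sgd(f_min=0.0, variant='sps+')
state = opt.init(params)

for _ in range(5):
  value, grads = jax.value_and_grad(loss_fn)(params)
  # polyak_sgd reads the current loss through the `value` extra argument.
  updates, state = opt.update(grads, state, params, value=value)
  params = optax.apply_updates(params, updates)

With ``variant='sps+'``, a noisy ``value`` that dips below ``f_min`` yields a zero step instead of a negative one.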

12 changes: 11 additions & 1 deletion optax/_src/transform.py
@@ -19,6 +19,7 @@

import chex
import jax
from jax import nn
import jax.numpy as jnp
from optax import tree_utils as otu
from optax._src import base
@@ -1393,6 +1394,7 @@ def scale_by_polyak(
f_min: float = 0.0,
max_learning_rate: float = 1.0,
eps: float = 0.0,
variant: str = 'sps',
) -> base.GradientTransformationExtraArgs:
r"""Scales the update by Polyak's step-size.
@@ -1403,6 +1405,7 @@
to :math:`f^\star` in the formula above.
max_learning_rate: a maximum step size to use (defaults to 1).
eps: a value to add in the denominator of the update (defaults to 0).
variant: either ``'sps'`` or ``'sps+'`` (defaults to ``'sps'``).
Returns:
A :class:`optax.GradientTransformationExtraArgs`, where the ``update``
@@ -1433,11 +1436,18 @@ def update_fn(
"""
del params, extra_args
grad_sq_norm = otu.tree_l2_norm(updates, squared=True)
gap = value - f_min
if variant == 'sps':
pass
elif variant == 'sps+':
gap = nn.relu(gap)
else:
raise ValueError(f'Invalid argument value for Polyak SGD: {variant=}')
# avoid division by zero
step = jnp.where(
grad_sq_norm + eps <= jnp.finfo(float).eps,
jnp.array(0.0),
jnp.minimum((value - f_min) / (grad_sq_norm + eps), max_learning_rate),
jnp.minimum(gap / (grad_sq_norm + eps), max_learning_rate),
)
updates = otu.tree_scalar_mul(step, updates)
return updates, state
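To make the difference between the two variants concrete, here is a small numerical sketch of the step-size computation; the loss value and gradient norm are made-up numbers, and the division-by-zero guard from ``update_fn`` is omitted:

import jax.numpy as jnp
from jax import nn

f_min, eps, max_learning_rate = 0.0, 0.0, 1.0
value = -0.5        # noisy loss estimate that falls below the f_min guess
grad_sq_norm = 4.0  # squared l2 norm of the gradient

gap = value - f_min                # -0.5, used as-is by 'sps'
gap_plus = nn.relu(value - f_min)  # 0.0, the non-negative part used by 'sps+'

step_sps = jnp.minimum(gap / (grad_sq_norm + eps), max_learning_rate)            # -0.125
step_sps_plus = jnp.minimum(gap_plus / (grad_sq_norm + eps), max_learning_rate)  # 0.0

Under ``'sps'`` a loss estimate below ``f_min`` produces a negative step size, which flips the sign of the update; ``'sps+'`` clamps the gap at zero, so that update leaves the parameters unchanged.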
