Gauss-Newton and Levenberg-Marquardt #920

Draft
wants to merge 17 commits into
base: main
6 changes: 6 additions & 0 deletions optax/__init__.py
@@ -47,6 +47,8 @@
from optax._src.alias import sgd
from optax._src.alias import sm3
from optax._src.alias import yogi
from optax._src.alias import gauss_newton
from optax._src.alias import levenberg_marquardt
from optax._src.base import EmptyState
from optax._src.base import GradientTransformation
from optax._src.base import GradientTransformationExtraArgs
@@ -111,6 +113,7 @@
from optax._src.transform import scale_by_distance_over_gradients
from optax._src.transform import scale_by_learning_rate
from optax._src.transform import scale_by_lion
from optax._src.transform import scale_by_madsen_trust_region
from optax._src.transform import scale_by_novograd
from optax._src.transform import scale_by_optimistic_gradient
from optax._src.transform import scale_by_param_block_norm
@@ -271,6 +274,8 @@
"FactoredState",
"flatten",
"fromage",
"gauss_newton",
"levenberg_marquardt",
"global_norm",
"GradientTransformation",
"GradientTransformationExtraArgs",
@@ -334,6 +339,7 @@
"scale_by_belief",
"scale_by_lion",
"scale_by_factored_rms",
"scale_by_madsen_trust_region",
"scale_by_novograd",
"scale_by_param_block_norm",
"scale_by_param_block_rms",
66 changes: 66 additions & 0 deletions optax/_src/alias.py
@@ -18,6 +18,7 @@
from typing import Any, Callable, Optional, Union

import jax.numpy as jnp
import jax.scipy as jsp

from optax._src import base
from optax._src import clipping
@@ -1844,3 +1845,68 @@ def polyak_sgd(
max_learning_rate=max_learning_rate, f_min=f_min, eps=eps
),
)


def gauss_newton(
linear_solver: Callable = jsp.sparse.linalg.cg,
is_compositional: bool = False,
use_normal_eqs: bool = True,
) -> base.GradientTransformationExtraArgs:
"""The Gauss-Newton optimizer.

Apply the Gauss-Newton method to a nonlinear least-squares problem or to a
more general compositional problem.

Args:
linear_solver: instance of a linear solver (e.g. jsp.sparse.linalg.cg).
is_compositional: if true, solve a compositional problem (needs outer_hvp);
otherwise solve a classical nonlinear least-squares problem.
use_normal_eqs: if true, solve the normal equations.
Returns:
The Gauss-Newton update.
"""
return transform.scale_by_gauss_newton(
linear_solver=linear_solver,
is_compositional=is_compositional,
use_normal_eqs=use_normal_eqs,
)
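
For reference, a minimal usage sketch of the Gauss-Newton transformation added above, assuming the API as written in this draft (the update takes the residuals, the state, the params, and an inner_jvp keyword); the exponential-fit residuals_fn and its data are a hypothetical example, not part of the PR.

import jax
import jax.numpy as jnp
import optax

def residuals_fn(params):
  # Residuals of fitting y = exp(b * x) to toy data (hypothetical example).
  x = jnp.linspace(0.0, 1.0, 20)
  y = jnp.exp(-1.5 * x)
  return jnp.exp(params['b'] * x) - y

params = {'b': jnp.asarray(-0.5)}
opt = optax.gauss_newton()
state = opt.init(params)
for _ in range(5):
  # Linearize the residuals at the current params to obtain v -> J v.
  residuals, inner_jvp = jax.linearize(residuals_fn, params)
  updates, state = opt.update(residuals, state, params, inner_jvp=inner_jvp)
  params = optax.apply_updates(params, updates)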


def levenberg_marquardt(
is_compositional: bool = False,
use_normal_eqs: bool = True,
linear_solver: Callable = jsp.sparse.linalg.cg,
init_damping_parameter: float = 1e-3,
increase_factor: float = 2.0,
max_steps: int = 30,
) -> base.GradientTransformationExtraArgs:
"""The Levenberg-Marquardt optimizer.

Apply the gain ratio trust-region search to the regularized Gauss-Newton step.
See algorithm 6.18 in “Introduction to Optimization and Data Fitting” by
K. Madsen & H. B. Nielsen.

Args:
is_compositional: if true solve a compositional problem (needs outer_hvp),
else solve a classical least squares.
use_normal_eqs: if true solve the normal equations.
linear_solver: instance of linear solver (e.g. jsp.sparse.linalg.cg).
init_damping_parameter: initial value of the damping parameter.
increase_factor: initial value of the factor used to increase the damping
parameter when a step is rejected.
max_steps: maximum number of steps before stopping the search loop.
Returns:
The Levenberg-Marquardt update.
"""

opt = transform.scale_by_gauss_newton(
linear_solver=linear_solver,
is_compositional=is_compositional,
use_normal_eqs=use_normal_eqs,
)

return transform.scale_by_madsen_trust_region(
gn_optimizer=opt,
init_damping_parameter=init_damping_parameter,
increase_factor=increase_factor,
max_steps=max_steps
)
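
Similarly, a hedged usage sketch of the Levenberg-Marquardt wrapper; with the update signature as written in this draft, the trust-region state is passed first, followed by the params and a residuals_fn keyword. The two-parameter exponential fit is again a hypothetical example.

import jax.numpy as jnp
import optax

def residuals_fn(params):
  # Residuals of fitting y = a * exp(b * x) to toy data (hypothetical example).
  x = jnp.linspace(0.0, 1.0, 20)
  y = 2.0 * jnp.exp(-1.5 * x)
  return params['a'] * jnp.exp(params['b'] * x) - y

params = {'a': jnp.asarray(1.0), 'b': jnp.asarray(0.0)}
opt = optax.levenberg_marquardt()
state = opt.init(params)
for _ in range(10):
  # Each update runs the gain-ratio search for an acceptable damping parameter.
  updates, state = opt.update(state, params, residuals_fn=residuals_fn)
  params = optax.apply_updates(params, updates)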
223 changes: 223 additions & 0 deletions optax/_src/transform.py
@@ -27,6 +27,7 @@
from optax._src import numerics
from optax._src import utils
from optax._src import wrappers
from optax._src import update as optax_update

abs_sq = numerics.abs_sq

@@ -1409,6 +1410,228 @@ def update_fn(
return base.GradientTransformationExtraArgs(_init_empty_state, update_fn)


class GaussNewtonState(NamedTuple):
"""State for scale_by_gauss_newton."""
count: chex.Array

def scale_by_gauss_newton(
linear_solver: Callable = jax.scipy.sparse.linalg.cg,
is_compositional: bool = False,
use_normal_eqs: bool = True,
) -> base.GradientTransformationExtraArgs:
"""Return the Gauss-Newton updates.

Apply the Gauss-Newton method to a nonlinear least-squares problem or to a
more general compositional problem.

Args:
linear_solver: solver that, given a function matvec computing
matvec(x) = Ax and a pytree b, solves the linear system Ax = b.
is_compositional: whether to solve a compositional problem (needs outer_hvp)
rather than a classical nonlinear least-squares problem.
use_normal_eqs: if true, solve the normal equations.
Returns:
The Gauss-Newton update.
"""
def init_fn(params):
del params
return GaussNewtonState(count=jnp.zeros([], jnp.int32))

def _make_ridge_gnvp(matvec: Callable, ridge: float = 0.0):
"""Returns the operator equivalent to the sum of matvec and ridge*I."""
def ridge_matvec(v: Any) -> Any:
return otu.tree_add_scalar_mul(matvec(v), ridge, v)
return ridge_matvec

def _build_gnvp(residuals, params, inner_jvp,
outer_grad, outer_hvp, damping_parameter):
"""Builds the matrix and the vector needed for the linear system."""
inner_vjp_ = jax.linear_transpose(inner_jvp, params)
inner_vjp = lambda x: inner_vjp_(x)[0]
if use_normal_eqs:
if is_compositional:
gnvp_fn = lambda x: inner_vjp(outer_hvp(inner_jvp(x)))
grad = inner_vjp(outer_grad)
else:
gnvp_fn = lambda x: inner_vjp(inner_jvp(x))
grad = inner_vjp(residuals)
gnvp_fn = _make_ridge_gnvp(gnvp_fn, ridge=damping_parameter)
else:
raise ValueError('Solving without the normal equations is still work in progress.')
return gnvp_fn, grad

def update_fn(residuals, state, params, *, inner_jvp, damping_parameter=0.,
outer_grad=None, outer_hvp=None):
"""Return the Gauss-Newton updates.

Args:
residuals: the value of the residuals (inner function) computed at params.
state: the state of the transformation.
params: the parameters of the model.
inner_jvp: a function that computes v -> J v (where J is the Jacobian of
the inner function).
damping_parameter: the damping parameter (ridge term added to the
Gauss-Newton matrix).
outer_grad: the gradient of the outer function computed at residuals.
outer_hvp: a function that computes v -> H v (where H is the Hessian of
the outer function in compositional problems).
Returns:
The Gauss-Newton updates and the new state.
"""

# build gnvp and gradient
matvec, b = _build_gnvp(residuals, params, inner_jvp,
outer_grad, outer_hvp, damping_parameter)

# solve linear system
updates = linear_solver(matvec, otu.tree_scalar_mul(-1, b))[0]

count_inc = utils.safe_int32_increment(state.count)
return updates, GaussNewtonState(count=count_inc)

return base.GradientTransformationExtraArgs(init_fn, update_fn)
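
For reference, a self-contained sketch of the normal-equation solve this transformation performs in the classical (non-compositional) case, i.e. solving (J^T J + damping * I) d = -J^T r using only Jacobian-vector products; the helper name and its arguments are illustrative, not part of the PR.

import jax
import jax.numpy as jnp

def gauss_newton_step(residuals_fn, params, damping=0.0):
  # Linearize the residuals: jvp computes v -> J v at the current params.
  residuals, jvp = jax.linearize(residuals_fn, params)
  # Transpose the linearization to obtain u -> J^T u.
  vjp = jax.linear_transpose(jvp, params)
  # Matrix-vector product with the damped Gauss-Newton matrix J^T J + damping * I.
  matvec = lambda v: jax.tree_util.tree_map(
      lambda jtjv, vi: jtjv + damping * vi, vjp(jvp(v))[0], v)
  # Right-hand side -J^T r, then solve the system with conjugate gradients.
  rhs = jax.tree_util.tree_map(lambda x: -x, vjp(residuals)[0])
  step, _ = jax.scipy.sparse.linalg.cg(matvec, rhs)
  return step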


class ScaleByMadsenTrustRegionState(NamedTuple):
"""State for scale_by_madsen_trust_region"""
damping_parameter: float
increase_factor: float
gn_optimizer_state: base.OptState
accepted: bool
iter_num: int
value: Union[float, jax.Array]

def scale_by_madsen_trust_region(
gn_optimizer: base.GradientTransformation,
init_damping_parameter: float = 1e-3,
increase_factor: float = 2.0,
max_steps: int = 30,
) -> base.GradientTransformationExtraArgs:
"""Return the Gauss-Newton updates that satify the gain ratio test.

Modify the damping parameter of the GaussNewton optimizer based on the
algorithm 6.18 provided by K. Madsen & H. B. Nielsen in the book
“Introduction to Optimization and Data Fitting”.

Args:
gn_optimizer: instance of scale_by_gauss_newton GradientTransformation.
init_damping_parameter: initial value for the damping parameter.
increase_factor: initial value for the increase factor.
max_steps: maximum number of iterations before stopping the search loop.
Returns:
The Gauss-Newton update.
"""
def init_fn(params: base.Params) -> ScaleByMadsenTrustRegionState:
return ScaleByMadsenTrustRegionState(
damping_parameter=init_damping_parameter,
increase_factor=increase_factor,
gn_optimizer_state=gn_optimizer.init(params),
accepted=False,
iter_num=jnp.zeros([], jnp.int32),
value=jnp.array(jnp.inf),
)

def _gain_ratio(value, value_new, updates, grad, mu):
gain_ratio_denom = 0.5 * otu.tree_vdot(updates,
otu.tree_sub(otu.tree_scalar_mul(mu, updates), grad))
return (value - value_new) / gain_ratio_denom

def _gain_ratio_test_true(updates, mu, nu, rho):
del nu
mu = mu * jnp.maximum(1/3, 1-(2*rho-1)**3)
nu = 2.0
accepted = True
return updates, accepted, mu, nu

def _gain_ratio_test_false(updates, mu, nu, rho):
del rho
mu = mu * nu
nu = 2 * nu
accepted = False
return otu.tree_zeros_like(updates), accepted, mu, nu

def update_fn(
search_state: ScaleByMadsenTrustRegionState,
params: base.Params,
*,
residuals_fn: Callable[..., Union[jax.Array, float]],
**extra_args: Any,
) -> tuple[base.Updates, ScaleByMadsenTrustRegionState]:
"""Compute updates that satisfy the gain ratio test."""

# fetch arguments to be fed to residuals_fn from the extra_args
(fn_kwargs,), remaining_kwargs = utils._extract_fns_kwargs( # pylint: disable=protected-access
(residuals_fn,), extra_args
)
del remaining_kwargs
residuals_fn_ = functools.partial(residuals_fn, **fn_kwargs)

# compute value and grad for the current params
residuals, inner_jvp = jax.linearize(residuals_fn_, params)
value_fn = lambda x: 0.5*jnp.sum(residuals_fn_(x)**2)
value, grad = jax.value_and_grad(value_fn)(params)

def cond_fn(val) -> jax.Array:
updates, search_state = val
del updates
accepted = search_state.accepted
iter_num = search_state.iter_num
return (~accepted) & (iter_num <= max_steps)

def body_fn(val) -> tuple[base.Updates, ScaleByMadsenTrustRegionState]:
updates, search_state = val
damping_parameter = search_state.damping_parameter
increase_factor = search_state.increase_factor
value = search_state.value
iter_num = search_state.iter_num
opt_state = search_state.gn_optimizer_state

# compute GN update with current damping parameter
updates_new, opt_state = gn_optimizer.update(residuals, opt_state, params,
inner_jvp=inner_jvp,
damping_parameter=damping_parameter)
value_new = value_fn(optax_update.apply_updates(params, updates_new))

# apply gain ratio test
rho = _gain_ratio(value, value_new, updates_new, grad, damping_parameter)
updates_new, accepted, damping_parameter, increase_factor = jax.lax.cond(
rho > 0,
_gain_ratio_test_true,
_gain_ratio_test_false,
updates_new,
damping_parameter,
increase_factor, rho,
)

iter_num_inc = utils.safe_int32_increment(iter_num)
search_state = ScaleByMadsenTrustRegionState(
damping_parameter=damping_parameter,
increase_factor=increase_factor,
gn_optimizer_state=opt_state,
accepted=accepted,
iter_num=iter_num_inc,
value=value,
)
return updates_new, search_state

search_state = ScaleByMadsenTrustRegionState(
damping_parameter=search_state.damping_parameter,
increase_factor=search_state.increase_factor,
gn_optimizer_state=search_state.gn_optimizer_state,
accepted=False,
iter_num=jnp.zeros([], jnp.int32),
value=value,
)

# start search for damping parameter
updates, search_state = jax.lax.while_loop(cond_fn, body_fn,
(otu.tree_zeros_like(params), search_state))
return updates, search_state

return base.GradientTransformationExtraArgs(init_fn, update_fn)
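
For reference, the damping-parameter rule the search loop above implements (Algorithm 6.18 in Madsen & Nielsen), written as a standalone eager-mode sketch with illustrative names; the gain ratio compares the actual decrease of the objective with the decrease predicted by the local model.

import jax.numpy as jnp

def madsen_damping_update(value, value_new, step, grad, mu, nu):
  # Gain ratio rho = (F(x) - F(x + h)) / (0.5 * h^T (mu * h - g)).
  predicted_decrease = 0.5 * jnp.vdot(step, mu * step - grad)
  rho = (value - value_new) / predicted_decrease
  if rho > 0:
    # Step accepted: shrink the damping parameter and reset the increase factor.
    mu = mu * max(1.0 / 3.0, 1.0 - (2.0 * rho - 1.0) ** 3)
    nu = 2.0
  else:
    # Step rejected: increase the damping parameter more aggressively each time.
    mu = mu * nu
    nu = 2.0 * nu
  return rho > 0, mu, nu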


### Legacy symbols to be removed. ###

