Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rescale optimizer termination criteria #1073

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
191ee2f
Rescale optimizer termination criteria
f0uriest Jun 25, 2024
2b56bcc
Merge branch 'master' into rc/stopping_criteria
f0uriest Jul 9, 2024
97743d6
Fix scaling of g_norm
f0uriest Jul 9, 2024
edfea84
Update default augmented lagrangian hyperparams for scaled termination
f0uriest Jul 9, 2024
ad85765
Merge branch 'master' into rc/stopping_criteria
f0uriest Jul 9, 2024
93c1404
Merge branch 'master' into rc/stopping_criteria
f0uriest Aug 27, 2024
245ecf4
Merge branch 'master' into rc/stopping_criteria
f0uriest Aug 29, 2024
cc73f99
Merge branch 'master' into rc/stopping_criteria
f0uriest Sep 27, 2024
90a0c0a
Merge branch 'master' into rc/stopping_criteria
dpanici Sep 27, 2024
dd40060
Merge branch 'master' into rc/stopping_criteria
dpanici Oct 30, 2024
b6a2298
Merge branch 'master' into rc/stopping_criteria
YigitElma Nov 1, 2024
9c2a806
Merge branch 'master' into rc/stopping_criteria
f0uriest Nov 12, 2024
554663f
Merge branch 'master' into rc/stopping_criteria
f0uriest Nov 12, 2024
396e1c5
Merge branch 'rc/stopping_criteria' of github.com:PlasmaControl/DESC …
f0uriest Nov 12, 2024
900f06b
Few more tweaks
f0uriest Nov 13, 2024
b9f2245
Merge branch 'rc/coil_test' into rc/stopping_criteria
f0uriest Nov 13, 2024
ff40b3f
Merge branch 'master' into rc/stopping_criteria
f0uriest Nov 13, 2024
6199fc8
Merge branch 'master' into rc/stopping_criteria
YigitElma Nov 15, 2024
1711c5e
Merge branch 'master' into rc/stopping_criteria
f0uriest Nov 19, 2024
a8f0a54
Remove todo comment
f0uriest Nov 21, 2024
fe5bd89
Update changelog
f0uriest Nov 21, 2024
e390193
Merge branch 'master' into rc/stopping_criteria
dpanici Nov 21, 2024
24da16f
Merge branch 'master' into rc/stopping_criteria
YigitElma Nov 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ New Features
* use of both this and the ``QuadraticFlux`` objective allows for REGCOIL solutions to be obtained through the optimization framework, and combined with other objectives as well.
- Changes local area weighting of Bn in QuadraticFlux objective to be the square root of the local area element (Note that any existing optimizations using this objective may need different weights to achieve the same result now.)
- Adds a new tutorial showing how to use``REGCOIL`` features.
- Adds an option ``scaled_termination`` (defaults to True) to all of the desc optimizers to measure the norms for ``xtol`` and ``gtol`` in the scaled norm provided by ``x_scale`` (which defaults to using an adaptive scaling based on the Jacobian or Hessian). This should make things more robust when optimizing parameters with widely different magnitudes. The old behavior can be recovered by passing ``options={"scaled_termination": False}``.


Bug Fixes
Expand Down
50 changes: 30 additions & 20 deletions desc/optimize/aug_lagrangian.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ def fmin_auglag( # noqa: C901
problem dimension. Set it to ``"auto"`` in order to use an automatic heuristic
for choosing the initial scale. The heuristic is described in [2]_, p.143.
By default uses ``"auto"``.

- ``"scaled_termination"`` : Whether to evaluate termination criteria for
``xtol`` and ``gtol`` in scaled / normalized units (default) or base units.

Returns
-------
Expand Down Expand Up @@ -312,18 +313,6 @@ def laghess(z, y, mu, *args):
y = jnp.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0)
y, mu, c = jnp.broadcast_arrays(y, mu, c)

# notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1
omega = options.pop("omega", 1.0)
eta = options.pop("eta", 1.0)
alpha_omega = options.pop("alpha_omega", 1.0)
beta_omega = options.pop("beta_omega", 1.0)
alpha_eta = options.pop("alpha_eta", 0.1)
beta_eta = options.pop("beta_eta", 0.9)
tau = options.pop("tau", 10)

gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol)
ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol)

L = lagfun(f, c, y, mu)
g = laggrad(z, y, mu, *args)
ngev += 1
Expand All @@ -338,6 +327,7 @@ def laghess(z, y, mu, *args):
maxiter = setdefault(maxiter, z.size * 100)
max_nfev = options.pop("max_nfev", 5 * maxiter + 1)
max_dx = options.pop("max_dx", jnp.inf)
scaled_termination = options.pop("scaled_termination", True)

hess_scale = isinstance(x_scale, str) and x_scale in ["hess", "auto"]
if hess_scale:
Expand All @@ -353,7 +343,9 @@ def laghess(z, y, mu, *args):

g_h = g * d
H_h = d * H * d[:, None]
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould
# scipy : norm of the scaled x, as used in scipy
Expand Down Expand Up @@ -399,7 +391,7 @@ def laghess(z, y, mu, *args):
)
subproblem = methods[tr_method]

z_norm = jnp.linalg.norm(z, ord=2)
z_norm = jnp.linalg.norm(((z * scale_inv) if scaled_termination else z), ord=2)
success = None
message = None
step_norm = jnp.inf
Expand All @@ -412,6 +404,18 @@ def laghess(z, y, mu, *args):
if g_norm < gtol and constr_violation < ctol:
success, message = True, STATUS_MESSAGES["gtol"]

# notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1
omega = options.pop("omega", min(g_norm, 1e-2) if scaled_termination else 1.0)
eta = options.pop("eta", min(constr_violation, 1e-2) if scaled_termination else 1.0)
alpha_omega = options.pop("alpha_omega", 1.0)
beta_omega = options.pop("beta_omega", 1.0)
alpha_eta = options.pop("alpha_eta", 0.1)
beta_eta = options.pop("beta_eta", 0.9)
tau = options.pop("tau", 10)

gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol)
ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol)

if verbose > 1:
print_header_nonlinear(True, "Penalty param", "max(|mltplr|)")
print_iteration_nonlinear(
Expand Down Expand Up @@ -493,7 +497,7 @@ def laghess(z, y, mu, *args):
success, message = check_termination(
actual_reduction,
f,
step_norm,
(step_h_norm if scaled_termination else step_norm),
z_norm,
g_norm,
Lreduction_ratio,
Expand Down Expand Up @@ -536,7 +540,9 @@ def laghess(z, y, mu, *args):
scale, scale_inv = compute_hess_scale(H)
v, dv = cl_scaling_vector(z, g, lb, ub)
v = jnp.where(dv != 0, v * scale_inv, v)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# updating augmented lagrangian params
if g_norm < gtolk:
Expand Down Expand Up @@ -565,9 +571,13 @@ def laghess(z, y, mu, *args):

v, dv = cl_scaling_vector(z, g, lb, ub)
v = jnp.where(dv != 0, v * scale_inv, v)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

z_norm = jnp.linalg.norm(z, ord=2)
z_norm = jnp.linalg.norm(
((z * scale_inv) if scaled_termination else z), ord=2
)
d = v**0.5 * scale
diag_h = g * dv * scale
g_h = g * d
Expand All @@ -580,7 +590,7 @@ def laghess(z, y, mu, *args):
success, message = False, STATUS_MESSAGES["callback"]

else:
step_norm = actual_reduction = 0
step_norm = step_h_norm = actual_reduction = 0

iteration += 1
if verbose > 1:
Expand Down
49 changes: 30 additions & 19 deletions desc/optimize/aug_lagrangian_ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@
value decomposition. ``"cho"`` is generally the fastest for large systems,
especially on GPU, but may be less accurate for badly scaled systems.
``"svd"`` is the most accurate but significantly slower. Default ``"qr"``.
- ``"scaled_termination"`` : Whether to evaluate termination criteria for
``xtol`` and ``gtol`` in scaled / normalized units (default) or base units.

Returns
-------
Expand Down Expand Up @@ -254,18 +256,6 @@
y = jnp.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0)
y, mu, c = jnp.broadcast_arrays(y, mu, c)

# notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1
omega = options.pop("omega", 1.0)
eta = options.pop("eta", 1.0)
alpha_omega = options.pop("alpha_omega", 1.0)
beta_omega = options.pop("beta_omega", 1.0)
alpha_eta = options.pop("alpha_eta", 0.1)
beta_eta = options.pop("beta_eta", 0.9)
tau = options.pop("tau", 10)

gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol)
ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol)

L = lagfun(f, c, y, mu)
J = lagjac(z, y, mu, *args)
Lcost = 1 / 2 * jnp.dot(L, L)
Expand All @@ -276,6 +266,7 @@
maxiter = setdefault(maxiter, z.size * 100)
max_nfev = options.pop("max_nfev", 5 * maxiter + 1)
max_dx = options.pop("max_dx", jnp.inf)
scaled_termination = options.pop("scaled_termination", True)

jac_scale = isinstance(x_scale, str) and x_scale in ["jac", "auto"]
if jac_scale:
Expand All @@ -291,7 +282,9 @@

g_h = g * d
J_h = J * d
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould
# scipy : norm of the scaled x, as used in scipy
Expand Down Expand Up @@ -332,7 +325,7 @@

callback = setdefault(callback, lambda *args: False)

z_norm = jnp.linalg.norm(z, ord=2)
z_norm = jnp.linalg.norm(((z * scale_inv) if scaled_termination else z), ord=2)
success = None
message = None
step_norm = jnp.inf
Expand All @@ -345,6 +338,18 @@
if g_norm < gtol and constr_violation < ctol:
success, message = True, STATUS_MESSAGES["gtol"]

# notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1
omega = options.pop("omega", min(g_norm, 1e-2) if scaled_termination else 1.0)
eta = options.pop("eta", min(constr_violation, 1e-2) if scaled_termination else 1.0)
alpha_omega = options.pop("alpha_omega", 1.0)
beta_omega = options.pop("beta_omega", 1.0)
alpha_eta = options.pop("alpha_eta", 0.1)
beta_eta = options.pop("beta_eta", 0.9)
tau = options.pop("tau", 10)

gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol)
ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol)

if verbose > 1:
print_header_nonlinear(True, "Penalty param", "max(|mltplr|)")
print_iteration_nonlinear(
Expand Down Expand Up @@ -454,7 +459,7 @@
success, message = check_termination(
actual_reduction,
cost,
step_norm,
(step_h_norm if scaled_termination else step_norm),
z_norm,
g_norm,
Lreduction_ratio,
Expand Down Expand Up @@ -492,7 +497,9 @@
scale, scale_inv = compute_jac_scale(J, scale_inv)
v, dv = cl_scaling_vector(z, g, lb, ub)
v = jnp.where(dv != 0, v * scale_inv, v)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# updating augmented lagrangian params
if g_norm < gtolk:
Expand All @@ -516,9 +523,13 @@

v, dv = cl_scaling_vector(z, g, lb, ub)
v = jnp.where(dv != 0, v * scale_inv, v)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

z_norm = jnp.linalg.norm(z, ord=2)
z_norm = jnp.linalg.norm(
((z * scale_inv) if scaled_termination else z), ord=2
)
d = v**0.5 * scale
diag_h = g * dv * scale
g_h = g * d
Expand All @@ -531,7 +542,7 @@
success, message = False, STATUS_MESSAGES["callback"]

else:
step_norm = actual_reduction = 0
step_norm = step_h_norm = actual_reduction = 0

Check warning on line 545 in desc/optimize/aug_lagrangian_ls.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/aug_lagrangian_ls.py#L545

Added line #L545 was not covered by tests

iteration += 1
if verbose > 1:
Expand Down
21 changes: 15 additions & 6 deletions desc/optimize/fmin_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def fmintr( # noqa: C901
problem dimension. Set it to ``"auto"`` in order to use an automatic heuristic
for choosing the initial scale. The heuristic is described in [2]_, p.143.
By default uses ``"auto"``.
- ``"scaled_termination"`` : Whether to evaluate termination criteria for
``xtol`` and ``gtol`` in scaled / normalized units (default) or base units.

Returns
-------
Expand Down Expand Up @@ -222,6 +224,7 @@ def fmintr( # noqa: C901
maxiter = N * 100
max_nfev = options.pop("max_nfev", 5 * maxiter + 1)
max_dx = options.pop("max_dx", jnp.inf)
scaled_termination = options.pop("scaled_termination", True)

hess_scale = isinstance(x_scale, str) and x_scale in ["hess", "auto"]
if hess_scale:
Expand All @@ -237,7 +240,9 @@ def fmintr( # noqa: C901

g_h = g * d
H_h = d * H * d[:, None]
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould
# scipy : norm of the scaled x, as used in scipy
Expand Down Expand Up @@ -283,7 +288,7 @@ def fmintr( # noqa: C901
)
subproblem = methods[tr_method]

x_norm = jnp.linalg.norm(x, ord=2)
x_norm = jnp.linalg.norm(((x * scale_inv) if scaled_termination else x), ord=2)
success = None
message = None
step_norm = jnp.inf
Expand Down Expand Up @@ -366,7 +371,7 @@ def fmintr( # noqa: C901
success, message = check_termination(
actual_reduction,
f,
step_norm,
(step_h_norm if scaled_termination else step_norm),
x_norm,
g_norm,
reduction_ratio,
Expand Down Expand Up @@ -410,8 +415,12 @@ def fmintr( # noqa: C901
g_h = g * d
H_h = d * H * d[:, None]

x_norm = jnp.linalg.norm(x, ord=2)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
x_norm = jnp.linalg.norm(
((x * scale_inv) if scaled_termination else x), ord=2
)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

if g_norm < gtol:
success, message = True, STATUS_MESSAGES["gtol"]
Expand All @@ -421,7 +430,7 @@ def fmintr( # noqa: C901

allx.append(x)
else:
step_norm = actual_reduction = 0
step_norm = step_h_norm = actual_reduction = 0

iteration += 1
if verbose > 1:
Expand Down
23 changes: 16 additions & 7 deletions desc/optimize/least_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def lsqtr( # noqa: C901
value decomposition. ``"cho"`` is generally the fastest for large systems,
especially on GPU, but may be less accurate for badly scaled systems.
``"svd"`` is the most accurate but significantly slower. Default ``"qr"``.
- ``"scaled_termination"`` : Whether to evaluate termination criteria for
``xtol`` and ``gtol`` in scaled / normalized units (default) or base units.

Returns
-------
Expand Down Expand Up @@ -183,6 +185,7 @@ def lsqtr( # noqa: C901
maxiter = setdefault(maxiter, n * 100)
max_nfev = options.pop("max_nfev", 5 * maxiter + 1)
max_dx = options.pop("max_dx", jnp.inf)
scaled_termination = options.pop("scaled_termination", True)

jac_scale = isinstance(x_scale, str) and x_scale in ["jac", "auto"]
if jac_scale:
Expand All @@ -198,7 +201,9 @@ def lsqtr( # noqa: C901

g_h = g * d
J_h = J * d
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould
# scipy : norm of the scaled x, as used in scipy
Expand Down Expand Up @@ -239,7 +244,7 @@ def lsqtr( # noqa: C901

callback = setdefault(callback, lambda *args: False)

x_norm = jnp.linalg.norm(x, ord=2)
x_norm = jnp.linalg.norm(((x * scale_inv) if scaled_termination else x), ord=2)
success = None
message = None
step_norm = jnp.inf
Expand Down Expand Up @@ -344,11 +349,11 @@ def lsqtr( # noqa: C901
)
alltr.append(trust_radius)
alpha *= tr_old / trust_radius
# TODO (#1395): does this need to move to the outer loop?

success, message = check_termination(
YigitElma marked this conversation as resolved.
Show resolved Hide resolved
actual_reduction,
cost,
step_norm,
(step_h_norm if scaled_termination else step_norm),
x_norm,
g_norm,
reduction_ratio,
Expand Down Expand Up @@ -386,8 +391,12 @@ def lsqtr( # noqa: C901

g_h = g * d
J_h = J * d
x_norm = jnp.linalg.norm(x, ord=2)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
x_norm = jnp.linalg.norm(
((x * scale_inv) if scaled_termination else x), ord=2
)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

if g_norm < gtol:
success, message = True, STATUS_MESSAGES["gtol"]
Expand All @@ -396,7 +405,7 @@ def lsqtr( # noqa: C901
success, message = False, STATUS_MESSAGES["callback"]

else:
step_norm = actual_reduction = 0
step_norm = step_h_norm = actual_reduction = 0

iteration += 1
if verbose > 1:
Expand Down
Loading
Loading