diff --git a/CHANGELOG.md b/CHANGELOG.md index 081dec97b..d36274333 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ New Features - Adds ``eq_fixed`` flag to ``ToroidalFlux`` to allow for the equilibrium/QFM surface to vary during optimization, useful for single-stage optimizations. - Adds tutorial notebook showcasing QFM surface capability. - Adds ``rotate_zeta`` function to ``desc.compat`` to rotate an ``Equilibrium`` around Z axis. +- Adds an option ``scaled_termination`` (defaults to True) to all of the desc optimizers to measure the norms for ``xtol`` and ``gtol`` in the scaled norm provided by ``x_scale`` (which defaults to using an adaptive scaling based on the Jacobian or Hessian). This should make things more robust when optimizing parameters with widely different magnitudes. The old behavior can be recovered by passing ``options={"scaled_termination": False}``. Bug Fixes diff --git a/desc/optimize/aug_lagrangian.py b/desc/optimize/aug_lagrangian.py index 0adc90c27..2d194cdd5 100644 --- a/desc/optimize/aug_lagrangian.py +++ b/desc/optimize/aug_lagrangian.py @@ -189,7 +189,8 @@ def fmin_auglag( # noqa: C901 problem dimension. Set it to ``"auto"`` in order to use an automatic heuristic for choosing the initial scale. The heuristic is described in [2]_, p.143. By default uses ``"auto"``. - + - ``"scaled_termination"`` : Whether to evaluate termination criteria for + ``xtol`` and ``gtol`` in scaled / normalized units (default) or base units. Returns ------- @@ -312,18 +313,6 @@ def laghess(z, y, mu, *args): y = jnp.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0) y, mu, c = jnp.broadcast_arrays(y, mu, c) - # notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1 - omega = options.pop("omega", 1.0) - eta = options.pop("eta", 1.0) - alpha_omega = options.pop("alpha_omega", 1.0) - beta_omega = options.pop("beta_omega", 1.0) - alpha_eta = options.pop("alpha_eta", 0.1) - beta_eta = options.pop("beta_eta", 0.9) - tau = options.pop("tau", 10) - - gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol) - ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol) - L = lagfun(f, c, y, mu) g = laggrad(z, y, mu, *args) ngev += 1 @@ -338,6 +327,7 @@ def laghess(z, y, mu, *args): maxiter = setdefault(maxiter, z.size * 100) max_nfev = options.pop("max_nfev", 5 * maxiter + 1) max_dx = options.pop("max_dx", jnp.inf) + scaled_termination = options.pop("scaled_termination", True) hess_scale = isinstance(x_scale, str) and x_scale in ["hess", "auto"] if hess_scale: @@ -353,7 +343,9 @@ def laghess(z, y, mu, *args): g_h = g * d H_h = d * H * d[:, None] - g_norm = jnp.linalg.norm(g * v, ord=jnp.inf) + g_norm = jnp.linalg.norm( + (g * v * scale if scaled_termination else g * v), ord=jnp.inf + ) # conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould # scipy : norm of the scaled x, as used in scipy @@ -399,7 +391,7 @@ def laghess(z, y, mu, *args): ) subproblem = methods[tr_method] - z_norm = jnp.linalg.norm(z, ord=2) + z_norm = jnp.linalg.norm(((z * scale_inv) if scaled_termination else z), ord=2) success = None message = None step_norm = jnp.inf @@ -412,6 +404,18 @@ def laghess(z, y, mu, *args): if g_norm < gtol and constr_violation < ctol: success, message = True, STATUS_MESSAGES["gtol"] + # notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1 + omega = options.pop("omega", min(g_norm, 1e-2) if scaled_termination else 1.0) + eta = options.pop("eta", min(constr_violation, 1e-2) if scaled_termination else 1.0) + alpha_omega = options.pop("alpha_omega", 1.0) + beta_omega = options.pop("beta_omega", 1.0) + alpha_eta = options.pop("alpha_eta", 0.1) + beta_eta = options.pop("beta_eta", 0.9) + tau = options.pop("tau", 10) + + gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol) + ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol) + if verbose > 1: print_header_nonlinear(True, "Penalty param", "max(|mltplr|)") print_iteration_nonlinear( @@ -493,7 +497,7 @@ def laghess(z, y, mu, *args): success, message = check_termination( actual_reduction, f, - step_norm, + (step_h_norm if scaled_termination else step_norm), z_norm, g_norm, Lreduction_ratio, @@ -536,7 +540,9 @@ def laghess(z, y, mu, *args): scale, scale_inv = compute_hess_scale(H) v, dv = cl_scaling_vector(z, g, lb, ub) v = jnp.where(dv != 0, v * scale_inv, v) - g_norm = jnp.linalg.norm(g * v, ord=jnp.inf) + g_norm = jnp.linalg.norm( + (g * v * scale if scaled_termination else g * v), ord=jnp.inf + ) # updating augmented lagrangian params if g_norm < gtolk: @@ -565,9 +571,13 @@ def laghess(z, y, mu, *args): v, dv = cl_scaling_vector(z, g, lb, ub) v = jnp.where(dv != 0, v * scale_inv, v) - g_norm = jnp.linalg.norm(g * v, ord=jnp.inf) + g_norm = jnp.linalg.norm( + (g * v * scale if scaled_termination else g * v), ord=jnp.inf + ) - z_norm = jnp.linalg.norm(z, ord=2) + z_norm = jnp.linalg.norm( + ((z * scale_inv) if scaled_termination else z), ord=2 + ) d = v**0.5 * scale diag_h = g * dv * scale g_h = g * d @@ -580,7 +590,7 @@ def laghess(z, y, mu, *args): success, message = False, STATUS_MESSAGES["callback"] else: - step_norm = actual_reduction = 0 + step_norm = step_h_norm = actual_reduction = 0 iteration += 1 if verbose > 1: diff --git a/desc/optimize/aug_lagrangian_ls.py b/desc/optimize/aug_lagrangian_ls.py index 6b0952bf4..2781ac674 100644 --- a/desc/optimize/aug_lagrangian_ls.py +++ b/desc/optimize/aug_lagrangian_ls.py @@ -173,6 +173,8 @@ def lsq_auglag( # noqa: C901 value decomposition. ``"cho"`` is generally the fastest for large systems, especially on GPU, but may be less accurate for badly scaled systems. ``"svd"`` is the most accurate but significantly slower. Default ``"qr"``. + - ``"scaled_termination"`` : Whether to evaluate termination criteria for + ``xtol`` and ``gtol`` in scaled / normalized units (default) or base units. Returns ------- @@ -254,18 +256,6 @@ def lagjac(z, y, mu, *args): y = jnp.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0) y, mu, c = jnp.broadcast_arrays(y, mu, c) - # notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1 - omega = options.pop("omega", 1.0) - eta = options.pop("eta", 1.0) - alpha_omega = options.pop("alpha_omega", 1.0) - beta_omega = options.pop("beta_omega", 1.0) - alpha_eta = options.pop("alpha_eta", 0.1) - beta_eta = options.pop("beta_eta", 0.9) - tau = options.pop("tau", 10) - - gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol) - ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol) - L = lagfun(f, c, y, mu) J = lagjac(z, y, mu, *args) Lcost = 1 / 2 * jnp.dot(L, L) @@ -276,6 +266,7 @@ def lagjac(z, y, mu, *args): maxiter = setdefault(maxiter, z.size * 100) max_nfev = options.pop("max_nfev", 5 * maxiter + 1) max_dx = options.pop("max_dx", jnp.inf) + scaled_termination = options.pop("scaled_termination", True) jac_scale = isinstance(x_scale, str) and x_scale in ["jac", "auto"] if jac_scale: @@ -291,7 +282,9 @@ def lagjac(z, y, mu, *args): g_h = g * d J_h = J * d - g_norm = jnp.linalg.norm(g * v, ord=jnp.inf) + g_norm = jnp.linalg.norm( + (g * v * scale if scaled_termination else g * v), ord=jnp.inf + ) # conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould # scipy : norm of the scaled x, as used in scipy @@ -332,7 +325,7 @@ def lagjac(z, y, mu, *args): callback = setdefault(callback, lambda *args: False) - z_norm = jnp.linalg.norm(z, ord=2) + z_norm = jnp.linalg.norm(((z * scale_inv) if scaled_termination else z), ord=2) success = None message = None step_norm = jnp.inf @@ -345,6 +338,18 @@ def lagjac(z, y, mu, *args): if g_norm < gtol and constr_violation < ctol: success, message = True, STATUS_MESSAGES["gtol"] + # notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1 + omega = options.pop("omega", min(g_norm, 1e-2) if scaled_termination else 1.0) + eta = options.pop("eta", min(constr_violation, 1e-2) if scaled_termination else 1.0) + alpha_omega = options.pop("alpha_omega", 1.0) + beta_omega = options.pop("beta_omega", 1.0) + alpha_eta = options.pop("alpha_eta", 0.1) + beta_eta = options.pop("beta_eta", 0.9) + tau = options.pop("tau", 10) + + gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol) + ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol) + if verbose > 1: print_header_nonlinear(True, "Penalty param", "max(|mltplr|)") print_iteration_nonlinear( @@ -454,7 +459,7 @@ def lagjac(z, y, mu, *args): success, message = check_termination( actual_reduction, cost, - step_norm, + (step_h_norm if scaled_termination else step_norm), z_norm, g_norm, Lreduction_ratio, @@ -492,7 +497,9 @@ def lagjac(z, y, mu, *args): scale, scale_inv = compute_jac_scale(J, scale_inv) v, dv = cl_scaling_vector(z, g, lb, ub) v = jnp.where(dv != 0, v * scale_inv, v) - g_norm = jnp.linalg.norm(g * v, ord=jnp.inf) + g_norm = jnp.linalg.norm( + (g * v * scale if scaled_termination else g * v), ord=jnp.inf + ) # updating augmented lagrangian params if g_norm < gtolk: @@ -516,9 +523,13 @@ def lagjac(z, y, mu, *args): v, dv = cl_scaling_vector(z, g, lb, ub) v = jnp.where(dv != 0, v * scale_inv, v) - g_norm = jnp.linalg.norm(g * v, ord=jnp.inf) + g_norm = jnp.linalg.norm( + (g * v * scale if scaled_termination else g * v), ord=jnp.inf + ) - z_norm = jnp.linalg.norm(z, ord=2) + z_norm = jnp.linalg.norm( + ((z * scale_inv) if scaled_termination else z), ord=2 + ) d = v**0.5 * scale diag_h = g * dv * scale g_h = g * d @@ -531,7 +542,7 @@ def lagjac(z, y, mu, *args): success, message = False, STATUS_MESSAGES["callback"] else: - step_norm = actual_reduction = 0 + step_norm = step_h_norm = actual_reduction = 0 iteration += 1 if verbose > 1: diff --git a/desc/optimize/fmin_scalar.py b/desc/optimize/fmin_scalar.py index cba5d7ef0..fda78fc63 100644 --- a/desc/optimize/fmin_scalar.py +++ b/desc/optimize/fmin_scalar.py @@ -159,6 +159,8 @@ def fmintr( # noqa: C901 problem dimension. Set it to ``"auto"`` in order to use an automatic heuristic for choosing the initial scale. The heuristic is described in [2]_, p.143. By default uses ``"auto"``. + - ``"scaled_termination"`` : Whether to evaluate termination criteria for + ``xtol`` and ``gtol`` in scaled / normalized units (default) or base units. Returns ------- @@ -222,6 +224,7 @@ def fmintr( # noqa: C901 maxiter = N * 100 max_nfev = options.pop("max_nfev", 5 * maxiter + 1) max_dx = options.pop("max_dx", jnp.inf) + scaled_termination = options.pop("scaled_termination", True) hess_scale = isinstance(x_scale, str) and x_scale in ["hess", "auto"] if hess_scale: @@ -237,7 +240,9 @@ def fmintr( # noqa: C901 g_h = g * d H_h = d * H * d[:, None] - g_norm = jnp.linalg.norm(g * v, ord=jnp.inf) + g_norm = jnp.linalg.norm( + (g * v * scale if scaled_termination else g * v), ord=jnp.inf + ) # conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould # scipy : norm of the scaled x, as used in scipy @@ -283,7 +288,7 @@ def fmintr( # noqa: C901 ) subproblem = methods[tr_method] - x_norm = jnp.linalg.norm(x, ord=2) + x_norm = jnp.linalg.norm(((x * scale_inv) if scaled_termination else x), ord=2) success = None message = None step_norm = jnp.inf @@ -366,7 +371,7 @@ def fmintr( # noqa: C901 success, message = check_termination( actual_reduction, f, - step_norm, + (step_h_norm if scaled_termination else step_norm), x_norm, g_norm, reduction_ratio, @@ -410,8 +415,12 @@ def fmintr( # noqa: C901 g_h = g * d H_h = d * H * d[:, None] - x_norm = jnp.linalg.norm(x, ord=2) - g_norm = jnp.linalg.norm(g * v, ord=jnp.inf) + x_norm = jnp.linalg.norm( + ((x * scale_inv) if scaled_termination else x), ord=2 + ) + g_norm = jnp.linalg.norm( + (g * v * scale if scaled_termination else g * v), ord=jnp.inf + ) if g_norm < gtol: success, message = True, STATUS_MESSAGES["gtol"] @@ -421,7 +430,7 @@ def fmintr( # noqa: C901 allx.append(x) else: - step_norm = actual_reduction = 0 + step_norm = step_h_norm = actual_reduction = 0 iteration += 1 if verbose > 1: diff --git a/desc/optimize/least_squares.py b/desc/optimize/least_squares.py index ef19c346c..56e7d6e0b 100644 --- a/desc/optimize/least_squares.py +++ b/desc/optimize/least_squares.py @@ -139,6 +139,8 @@ def lsqtr( # noqa: C901 value decomposition. ``"cho"`` is generally the fastest for large systems, especially on GPU, but may be less accurate for badly scaled systems. ``"svd"`` is the most accurate but significantly slower. Default ``"qr"``. + - ``"scaled_termination"`` : Whether to evaluate termination criteria for + ``xtol`` and ``gtol`` in scaled / normalized units (default) or base units. Returns ------- @@ -183,6 +185,7 @@ def lsqtr( # noqa: C901 maxiter = setdefault(maxiter, n * 100) max_nfev = options.pop("max_nfev", 5 * maxiter + 1) max_dx = options.pop("max_dx", jnp.inf) + scaled_termination = options.pop("scaled_termination", True) jac_scale = isinstance(x_scale, str) and x_scale in ["jac", "auto"] if jac_scale: @@ -198,7 +201,9 @@ def lsqtr( # noqa: C901 g_h = g * d J_h = J * d - g_norm = jnp.linalg.norm(g * v, ord=jnp.inf) + g_norm = jnp.linalg.norm( + (g * v * scale if scaled_termination else g * v), ord=jnp.inf + ) # conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould # scipy : norm of the scaled x, as used in scipy @@ -239,7 +244,7 @@ def lsqtr( # noqa: C901 callback = setdefault(callback, lambda *args: False) - x_norm = jnp.linalg.norm(x, ord=2) + x_norm = jnp.linalg.norm(((x * scale_inv) if scaled_termination else x), ord=2) success = None message = None step_norm = jnp.inf @@ -344,11 +349,11 @@ def lsqtr( # noqa: C901 ) alltr.append(trust_radius) alpha *= tr_old / trust_radius - # TODO (#1395): does this need to move to the outer loop? + success, message = check_termination( actual_reduction, cost, - step_norm, + (step_h_norm if scaled_termination else step_norm), x_norm, g_norm, reduction_ratio, @@ -386,8 +391,12 @@ def lsqtr( # noqa: C901 g_h = g * d J_h = J * d - x_norm = jnp.linalg.norm(x, ord=2) - g_norm = jnp.linalg.norm(g * v, ord=jnp.inf) + x_norm = jnp.linalg.norm( + ((x * scale_inv) if scaled_termination else x), ord=2 + ) + g_norm = jnp.linalg.norm( + (g * v * scale if scaled_termination else g * v), ord=jnp.inf + ) if g_norm < gtol: success, message = True, STATUS_MESSAGES["gtol"] @@ -396,7 +405,7 @@ def lsqtr( # noqa: C901 success, message = False, STATUS_MESSAGES["callback"] else: - step_norm = actual_reduction = 0 + step_norm = step_h_norm = actual_reduction = 0 iteration += 1 if verbose > 1: diff --git a/tests/test_examples.py b/tests/test_examples.py index ef61a7d12..c4d423c66 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -883,7 +883,7 @@ def mirrorRatio(params): ) optimizer = Optimizer("lsq-auglag") (eq, field), _ = optimizer.optimize( - (eq, field), objective, constraints, maxiter=100, verbose=3 + (eq, field), objective, constraints, maxiter=150, verbose=3 ) eq, _ = eq.solve(objective="force", verbose=3) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 4f0cce0e0..edf3bf1a4 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -837,9 +837,9 @@ def con(x): args=(), x_scale="auto", ftol=0, - xtol=1e-6, - gtol=1e-6, - ctol=1e-6, + xtol=1e-8, + gtol=1e-8, + ctol=1e-8, verbose=3, maxiter=None, options={"initial_multipliers": "least_squares"}, @@ -854,9 +854,9 @@ def con(x): args=(), x_scale="auto", ftol=0, - xtol=1e-6, - gtol=1e-6, - ctol=1e-6, + xtol=1e-8, + gtol=1e-8, + ctol=1e-8, verbose=3, maxiter=None, options={"initial_multipliers": "least_squares", "tr_method": "cho"}, @@ -881,9 +881,9 @@ def con(x): args=(), x_scale="auto", ftol=0, - xtol=1e-6, - gtol=1e-6, - ctol=1e-6, + xtol=1e-8, + gtol=1e-8, + ctol=1e-8, verbose=3, maxiter=None, options={"initial_multipliers": "least_squares"},