
bug fix for proximal scalar optimization #1419

Merged · 7 commits · Dec 4, 2024
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -24,10 +24,10 @@ New Features

Bug Fixes

- Fixes bug that occurs when taking the gradient of ``root`` and ``root_scalar`` with newer versions of JAX (>=0.4.34) and unpins the JAX version.
- Changes ``FixLambdaGauge`` constraint to now enforce zero flux surface average for lambda, instead of enforcing lambda(rho,0,0)=0 as it was incorrectly doing before.
- Fixes bug in ``softmin/softmax`` implementation.
- Fixes bug that occurred when using ``ProximalProjection`` with a scalar optimization algorithm.

v0.12.3
-------
19 changes: 19 additions & 0 deletions desc/optimize/_constraint_wrappers.py
@@ -836,6 +836,25 @@
        xopt, _ = self._update_equilibrium(x, store=False)
        return self._objective.compute_scaled_error(xopt, constants[0])

    def compute_scalar(self, x, constants=None):
Collaborator:

The docs of the grad and hess methods of ProximalProjection say that they compute the gradient/Hessian of self.compute_scalar, which this PR just adds. Should we change them?

Collaborator (Author):

I think the docs are correct. compute_scalar computes the sum-of-squares error $\frac{1}{2} \Sigma f^2$, compute_grad computes the gradient of that scalar error as $f^T \cdot J$, and compute_hess computes the Hessian of that scalar error as $J^T \cdot J$.
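(Editor's note, not part of the original comment: for reference, writing the scaled residual vector as $f(x)$ and its Jacobian as $J = \partial f / \partial x$, the quantities described above are)

$$
F(x) = \tfrac{1}{2}\, f(x)^T f(x), \qquad
\nabla F = J^T f \; \big(= (f^T J)^T\big), \qquad
\nabla^2 F = J^T J + \sum_i f_i \, \nabla^2 f_i \;\approx\; J^T J,
$$

where dropping the second-order residual term gives the Gauss-Newton Hessian that least-squares methods use.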

Collaborator:

Ah ok, wrappers have a different syntax. I was expecting something more like ObjectiveFunction's grad method:

    @jit
    def grad(self, x, constants=None):
        """Compute gradient vector of self.compute_scalar wrt x."""
        if constants is None:
            constants = self.constants
        return jnp.atleast_1d(
            Derivative(self.compute_scalar, mode="grad")(x, constants).squeeze()
        )

Collaborator (Author):

Both implementations should be equivalent in theory. @f0uriest, would changing this to take the gradient of the new compute_scalar be any faster than the existing implementation?
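(Editor's sketch, not DESC code: a toy check of the equivalence discussed above, using a made-up residual function in place of compute_scaled_error. Autodiff of the sum-of-squares scalar and the explicit $J^T f$ product give the same gradient.)

    import jax
    import jax.numpy as jnp


    def residual(x):
        # toy nonlinear residual vector standing in for compute_scaled_error
        return jnp.array([x[0] ** 2 - x[1], jnp.sin(x[1]) + x[0], x[0] * x[1] - 1.0])


    def scalar(x):
        # analogue of the new compute_scalar: sum-of-squares error
        return 0.5 * jnp.sum(residual(x) ** 2)


    x = jnp.array([1.2, -0.7])
    g_autodiff = jax.grad(scalar)(x)        # gradient via AD of the scalar
    J = jax.jacfwd(residual)(x)             # Jacobian of the residuals
    g_explicit = J.T @ residual(x)          # explicit J^T f product
    print(jnp.allclose(g_autodiff, g_explicit, atol=1e-6))  # True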

"""Compute the sum of squares error.

Parameters
----------
x : ndarray
State vector.
constants : list
Constant parameters passed to sub-objectives.

Returns
-------
f : float
Objective function scalar value.

"""
f = jnp.sum(self.compute_scaled_error(x, constants=constants) ** 2) / 2
return f

Codecov / codecov/patch check warning: added lines desc/optimize/_constraint_wrappers.py#L855-L856 were not covered by tests.

    def compute_unscaled(self, x, constants=None):
        """Compute the raw value of the objective function.

38 changes: 38 additions & 0 deletions tests/test_optimizer.py
@@ -324,6 +324,44 @@ def test_no_iterations():
    np.testing.assert_allclose(x0, out2["x"])


@pytest.mark.regression
@pytest.mark.optimize
def test_proximal_scalar():
Member:

Do we need a full regression test for this? Can't we just check that the function works?

Collaborator (Author):

I suppose a unit test would suffice in this case now that we know what the problem was, but I think we should keep the regression test. I think this is the only test where we use a "proximal-scalar" optimization algorithm, and that seems like something worth testing.
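(Editor's sketch, not part of this PR: a lighter-weight unit-style smoke test along the lines being discussed might look like the code below. The test name, the unit marker, the simplified constraint set, and the maxiter=1 setting are assumptions; it reuses the names imported elsewhere in tests/test_optimizer.py. The regression test as merged follows.)

    @pytest.mark.unit
    @pytest.mark.optimize
    def test_proximal_scalar_smoke():
        """Smoke test: a scalar objective under ProximalProjection can take a step."""
        eq = desc.examples.get("DSHAPE")
        optimizer = Optimizer("proximal-fmintr")  # proximal scalar optimizer
        objective = ObjectiveFunction(Volume(eq=eq, target=90))
        constraints = (
            FixIota(eq=eq),
            FixPressure(eq=eq),
            FixPsi(eq=eq),
            ForceBalance(eq=eq),  # force balance constraint for proximal projection
        )
        # a single iteration is enough to exercise compute_scalar/grad/hess;
        # the test passes if the call completes without raising
        [eq], _ = optimizer.optimize(
            things=eq,
            objective=objective,
            constraints=constraints,
            maxiter=1,
            verbose=0,
        )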

"""Test that proximal scalar optimization works."""
# test fix for GH issue #1403

# optimize to reduce DSHAPE volume from 100 m^3 to 90 m^3
eq = desc.examples.get("DSHAPE")
optimizer = Optimizer("proximal-fmintr") # proximal scalar optimizer
R_modes = np.vstack(
(
[0, 0, 0],
eq.surface.R_basis.modes[
np.max(np.abs(eq.surface.R_basis.modes), 1) > 1, :
],
)
)
Z_modes = eq.surface.Z_basis.modes[
np.max(np.abs(eq.surface.Z_basis.modes), 1) > 1, :
]
objective = ObjectiveFunction(Volume(eq=eq, target=90)) # scalar objective function
constraints = (
FixBoundaryR(eq=eq, modes=R_modes),
FixBoundaryZ(eq=eq, modes=Z_modes),
FixIota(eq=eq),
FixPressure(eq=eq),
FixPsi(eq=eq),
ForceBalance(eq=eq), # force balance constraint for proximal projection
)
[eq], _ = optimizer.optimize(
things=eq,
objective=objective,
constraints=constraints,
verbose=3,
)
np.testing.assert_allclose(eq.compute("V")["V"], 90)


@pytest.mark.regression
@pytest.mark.slow
@pytest.mark.optimize