
bug fix for proximal scalar optimization #1419

Merged
merged 7 commits into master from dd/hotfix on Dec 4, 2024
Conversation

@ddudt (Collaborator) commented Dec 2, 2024

Resolves #1403

@ddudt ddudt added the easy (Short and simple to code or review) and Bug fix (Something was fixed) labels on Dec 2, 2024
@ddudt ddudt self-assigned this Dec 2, 2024

codecov bot commented Dec 2, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 95.60%. Comparing base (bb14269) to head (5bd497f).
Report is 8 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1419      +/-   ##
==========================================
+ Coverage   95.57%   95.60%   +0.02%     
==========================================
  Files          96       96              
  Lines       24523    24526       +3     
==========================================
+ Hits        23438    23447       +9     
+ Misses       1085     1079       -6     
| Files with missing lines              | Coverage Δ                    |
| ------------------------------------- | ----------------------------- |
| desc/optimize/_constraint_wrappers.py | 97.22% <100.00%> (+1.98%) ⬆️ |

... and 3 files with indirect coverage changes

github-actions bot (Contributor) commented Dec 2, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_midres         |     +1.16 +/- 6.82     | +6.92e-03 +/- 4.06e-02 |  6.02e-01 +/- 3.1e-02  |  5.95e-01 +/- 2.6e-02  |
 test_build_transform_fft_highres        |     -0.17 +/- 1.97     | -1.63e-03 +/- 1.88e-02 |  9.54e-01 +/- 1.1e-02  |  9.56e-01 +/- 1.6e-02  |
 test_equilibrium_init_lowres            |     +0.12 +/- 1.28     | +4.52e-03 +/- 4.82e-02 |  3.78e+00 +/- 4.1e-02  |  3.78e+00 +/- 2.4e-02  |
 test_objective_compile_atf              |     -0.07 +/- 3.96     | -5.10e-03 +/- 3.09e-01 |  7.80e+00 +/- 2.5e-01  |  7.80e+00 +/- 1.8e-01  |
 test_objective_compute_atf              |     -0.17 +/- 2.37     | -1.79e-05 +/- 2.49e-04 |  1.05e-02 +/- 1.9e-04  |  1.05e-02 +/- 1.6e-04  |
 test_objective_jac_atf                  |     -0.28 +/- 2.34     | -5.18e-03 +/- 4.41e-02 |  1.88e+00 +/- 3.8e-02  |  1.88e+00 +/- 2.2e-02  |
 test_perturb_1                          |     +0.46 +/- 2.56     | +6.37e-02 +/- 3.57e-01 |  1.40e+01 +/- 1.9e-01  |  1.40e+01 +/- 3.0e-01  |
 test_proximal_jac_atf                   |     -0.11 +/- 1.49     | -9.19e-03 +/- 1.21e-01 |  8.09e+00 +/- 9.3e-02  |  8.10e+00 +/- 7.8e-02  |
 test_proximal_freeb_compute             |     -0.55 +/- 1.59     | -1.08e-03 +/- 3.12e-03 |  1.95e-01 +/- 2.8e-03  |  1.96e-01 +/- 1.4e-03  |
 test_solve_fixed_iter_compiled          |     +0.11 +/- 0.87     | +1.86e-02 +/- 1.45e-01 |  1.67e+01 +/- 6.3e-02  |  1.66e+01 +/- 1.3e-01  |
 test_build_transform_fft_lowres         |     +2.57 +/- 5.15     | +1.37e-02 +/- 2.75e-02 |  5.47e-01 +/- 1.7e-02  |  5.34e-01 +/- 2.2e-02  |
 test_equilibrium_init_medres            |     +0.97 +/- 3.46     | +3.99e-02 +/- 1.42e-01 |  4.16e+00 +/- 8.8e-02  |  4.12e+00 +/- 1.1e-01  |
 test_equilibrium_init_highres           |     +0.35 +/- 2.26     | +1.90e-02 +/- 1.22e-01 |  5.42e+00 +/- 4.3e-02  |  5.41e+00 +/- 1.1e-01  |
 test_objective_compile_dshape_current   |     -0.27 +/- 6.24     | -1.03e-02 +/- 2.41e-01 |  3.86e+00 +/- 2.3e-01  |  3.87e+00 +/- 7.5e-02  |
 test_objective_compute_dshape_current   |     -0.54 +/- 1.47     | -2.00e-05 +/- 5.40e-05 |  3.66e-03 +/- 4.4e-05  |  3.68e-03 +/- 3.1e-05  |
 test_objective_jac_dshape_current       |     +4.59 +/- 7.31     | +1.79e-03 +/- 2.84e-03 |  4.07e-02 +/- 2.0e-03  |  3.89e-02 +/- 2.0e-03  |
 test_perturb_2                          |     +2.62 +/- 1.51     | +5.01e-01 +/- 2.89e-01 |  1.96e+01 +/- 1.5e-01  |  1.91e+01 +/- 2.5e-01  |
 test_proximal_freeb_jac                 |     +1.18 +/- 1.63     | +8.80e-02 +/- 1.21e-01 |  7.54e+00 +/- 5.4e-02  |  7.45e+00 +/- 1.1e-01  |
 test_solve_fixed_iter                   |     +2.32 +/- 2.05     | +6.52e-01 +/- 5.76e-01 |  2.88e+01 +/- 2.6e-01  |  2.81e+01 +/- 5.2e-01  |
 test_LinearConstraintProjection_build   |     -0.36 +/- 2.06     | -8.25e-02 +/- 4.67e-01 |  2.26e+01 +/- 4.4e-01  |  2.26e+01 +/- 1.7e-01  |

@@ -836,6 +836,25 @@ def compute_scaled_error(self, x, constants=None):
xopt, _ = self._update_equilibrium(x, store=False)
return self._objective.compute_scaled_error(xopt, constants[0])

def compute_scalar(self, x, constants=None):
Collaborator commented:

Docs of grad and hess methods of ProximalProjection say that it computes the gradient of self.compute_scalar which this PR just adds. Should we change them?

Collaborator (Author) replied:

I think the docs are correct. compute_scalar computes the sum-of-squares error $\frac{1}{2} \Sigma f^2$, compute_grad computes the gradient of that scalar error as $f^T \cdot J$, and compute_hess computes the Hessian of that scalar error as $J^T \cdot J$.
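The three quantities line up as in standard least-squares convention; a minimal stand-alone sketch (toy residual function, not DESC code) that checks the gradient identity numerically:

```python
import numpy as np

def f(x):
    # toy residual vector standing in for the objective's scaled errors
    return np.array([x[0] ** 2 - 1.0, x[0] * x[1], x[1] - 2.0])

def compute_scalar(x):
    # sum-of-squares error: (1/2) * sum_i f_i(x)^2
    return 0.5 * np.sum(f(x) ** 2)

def jac(x, eps=1e-7):
    # forward-difference Jacobian J, column j = df/dx_j
    f0 = f(x)
    return np.column_stack([(f(x + eps * e) - f0) / eps for e in np.eye(x.size)])

x = np.array([1.5, -0.5])
J = jac(x)

# finite-difference gradient of the scalar sum-of-squares error ...
g_fd = np.array([(compute_scalar(x + 1e-7 * e) - compute_scalar(x)) / 1e-7
                 for e in np.eye(x.size)])

# ... matches J^T f, the product form described for compute_grad
assert np.allclose(g_fd, J.T @ f(x), rtol=1e-4)

# J^T J is the Gauss-Newton approximation used for the Hessian
H_gn = J.T @ J
```

Note that the exact Hessian of the scalar error adds a second-derivative term (sum of f_i times the Hessian of f_i) on top of JᵀJ; near a solution, where f ≈ 0, the Gauss-Newton form JᵀJ is the usual approximation.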

Collaborator replied:

Ah, OK: the wrappers have different syntax. I was expecting something more like ObjectiveFunction's grad method:

    @jit
    def grad(self, x, constants=None):
        """Compute gradient vector of self.compute_scalar wrt x."""
        if constants is None:
            constants = self.constants
        return jnp.atleast_1d(
            Derivative(self.compute_scalar, mode="grad")(x, constants).squeeze()
        )

Collaborator (Author) replied:

Both implementations should be equivalent in theory. @f0uriest, would changing this to take the gradient of the new compute_scalar be any faster than the existing implementation?
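As a sanity check of that equivalence (a toy JAX sketch, not DESC code): reverse-mode autodiff of the scalar error produces the same vector as the explicit JᵀF product, but evaluates it as a single vector-Jacobian product rather than by materializing J:

```python
import jax
import jax.numpy as jnp

def f(x):
    # toy residual vector standing in for the objective's scaled errors
    return jnp.array([x[0] ** 2 - 1.0, x[0] * x[1], x[1] - 2.0])

def compute_scalar(x):
    # sum-of-squares error: (1/2) * sum_i f_i(x)^2
    return 0.5 * jnp.sum(f(x) ** 2)

x = jnp.array([1.5, -0.5])

# existing-style implementation: build J explicitly, then form J^T f
J = jax.jacfwd(f)(x)
g_explicit = J.T @ f(x)

# proposed-style implementation: reverse-mode AD of compute_scalar,
# which computes J^T f as one vector-Jacobian product
g_autodiff = jax.grad(compute_scalar)(x)

assert jnp.allclose(g_explicit, g_autodiff)
```

Which route is faster in practice depends on the shapes involved: building J costs roughly one pass per input dimension, while the reverse-mode gradient costs about one backward pass total. For the proximal wrapper specifically, only profiling would settle the question raised here.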

@@ -324,6 +324,44 @@ def test_no_iterations():
np.testing.assert_allclose(x0, out2["x"])


@pytest.mark.regression
@pytest.mark.optimize
def test_proximal_scalar():
Member commented:

Do we need a full regression test for this? Can't we just check that the function works?

Collaborator (Author) replied:

I suppose a unit test would suffice in this case now that we know what the problem was, but I think we should keep the regression test. I think this is the only test where we use a "proximal-scalar" optimization algorithm, and that seems like something worth testing.
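A unit-style check along those lines could look something like this sketch (a hypothetical toy stand-in, not the actual ProximalProjection API), verifying that the new compute_scalar agrees with the residuals returned by compute_scaled_error:

```python
import numpy as np

class ToyProximalWrapper:
    """Toy stand-in mimicking the wrapper's residual/scalar convention."""

    def compute_scaled_error(self, x, constants=None):
        # pretend residuals at the proximally updated point
        return np.array([x[0] - 1.0, 2.0 * x[1]])

    def compute_scalar(self, x, constants=None):
        # the convention this PR adds: (1/2) * sum of squared residuals
        f = self.compute_scaled_error(x, constants)
        return 0.5 * np.sum(f ** 2)

def test_compute_scalar_matches_residuals():
    prox = ToyProximalWrapper()
    x = np.array([0.0, 3.0])
    f = prox.compute_scaled_error(x)
    np.testing.assert_allclose(prox.compute_scalar(x), 0.5 * np.sum(f ** 2))

test_compute_scalar_matches_residuals()
```

A check like this would catch a wrong scalar convention cheaply, while the regression test the author wants to keep exercises the full proximal-scalar optimization path end to end.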

@ddudt ddudt requested review from YigitElma and f0uriest December 3, 2024 15:38
@f0uriest f0uriest merged commit 4e723e4 into master Dec 4, 2024
25 checks passed
@f0uriest f0uriest deleted the dd/hotfix branch December 4, 2024 04:52
Labels: Bug fix (Something was fixed), easy (Short and simple to code or review)
Projects: None yet
Development: Successfully merging this pull request may close these issues:
TracerArrayConversionError with proximal scalar optimization
5 participants