Robust Coil Optimization (Wechsung) #1520

Draft
wants to merge 14 commits into
base: master

dpanici
Collaborator

@dpanici dpanici commented Jan 16, 2025

No description provided.

@dpanici dpanici requested a review from sinaatalay January 16, 2025 17:24

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Contributor

github-actions bot commented Jan 16, 2025

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_midres         |     -0.46 +/- 5.83     | -3.04e-03 +/- 3.84e-02 |  6.56e-01 +/- 2.7e-02  |  6.59e-01 +/- 2.8e-02  |
| test_build_transform_fft_highres        |     -0.43 +/- 5.94     | -4.09e-03 +/- 5.66e-02 |  9.48e-01 +/- 5.0e-02  |  9.52e-01 +/- 2.6e-02  |
| test_equilibrium_init_lowres            |     -3.66 +/- 4.19     | -1.58e-01 +/- 1.81e-01 |  4.15e+00 +/- 1.2e-01  |  4.30e+00 +/- 1.3e-01  |
| test_objective_compile_atf              |     -1.80 +/- 1.78     | -1.57e-01 +/- 1.56e-01 |  8.58e+00 +/- 1.1e-01  |  8.74e+00 +/- 1.1e-01  |
| test_objective_compute_atf              |     -8.66 +/- 4.21     | -1.57e-03 +/- 7.64e-04 |  1.66e-02 +/- 3.5e-04  |  1.82e-02 +/- 6.8e-04  |
| test_objective_jac_atf                  |     -0.25 +/- 1.85     | -4.99e-03 +/- 3.70e-02 |  2.00e+00 +/- 3.2e-02  |  2.00e+00 +/- 1.9e-02  |
| test_perturb_1                          |     -4.37 +/- 3.93     | -7.13e-01 +/- 6.42e-01 |  1.56e+01 +/- 5.0e-01  |  1.63e+01 +/- 4.0e-01  |
| test_proximal_jac_atf                   |     +0.27 +/- 1.53     | +2.19e-02 +/- 1.25e-01 |  8.18e+00 +/- 1.1e-01  |  8.16e+00 +/- 5.2e-02  |
| test_proximal_freeb_compute             |     +1.39 +/- 1.46     | +2.56e-03 +/- 2.68e-03 |  1.87e-01 +/- 2.1e-03  |  1.84e-01 +/- 1.7e-03  |
| test_solve_fixed_iter_compiled          |     -0.62 +/- 1.41     | -1.30e-01 +/- 2.95e-01 |  2.08e+01 +/- 4.8e-02  |  2.10e+01 +/- 2.9e-01  |
| test_objective_compute_ripple           |     -0.92 +/- 2.28     | -5.98e-03 +/- 1.49e-02 |  6.47e-01 +/- 1.4e-02  |  6.53e-01 +/- 5.9e-03  |
| test_objective_grad_ripple              |     -1.25 +/- 1.73     | -3.47e-02 +/- 4.79e-02 |  2.74e+00 +/- 3.7e-02  |  2.78e+00 +/- 3.0e-02  |
| test_build_transform_fft_lowres         |     -2.75 +/- 3.46     | -1.88e-02 +/- 2.36e-02 |  6.64e-01 +/- 1.6e-02  |  6.82e-01 +/- 1.7e-02  |
| test_equilibrium_init_medres            |     -3.02 +/- 3.68     | -1.48e-01 +/- 1.80e-01 |  4.74e+00 +/- 1.7e-01  |  4.89e+00 +/- 7.2e-02  |
| test_equilibrium_init_highres           |     +0.36 +/- 2.86     | +2.06e-02 +/- 1.63e-01 |  5.71e+00 +/- 1.2e-01  |  5.69e+00 +/- 1.1e-01  |
| test_objective_compile_dshape_current   |     +1.88 +/- 5.93     | +8.15e-02 +/- 2.58e-01 |  4.43e+00 +/- 1.7e-01  |  4.34e+00 +/- 1.9e-01  |
| test_objective_compute_dshape_current   |     +1.34 +/- 2.69     | +7.36e-05 +/- 1.48e-04 |  5.57e-03 +/- 8.2e-05  |  5.50e-03 +/- 1.2e-04  |
| test_objective_jac_dshape_current       |     +2.03 +/- 7.84     | +9.15e-04 +/- 3.53e-03 |  4.59e-02 +/- 2.5e-03  |  4.50e-02 +/- 2.5e-03  |
| test_perturb_2                          |     -1.45 +/- 2.22     | -3.19e-01 +/- 4.88e-01 |  2.17e+01 +/- 1.8e-01  |  2.20e+01 +/- 4.5e-01  |
| test_proximal_freeb_jac                 |     -1.19 +/- 2.86     | -8.71e-02 +/- 2.09e-01 |  7.21e+00 +/- 6.5e-02  |  7.29e+00 +/- 2.0e-01  |
| test_solve_fixed_iter                   |     +1.09 +/- 5.58     | +3.75e-01 +/- 1.92e+00 |  3.49e+01 +/- 1.0e+00  |  3.45e+01 +/- 1.6e+00  |
| test_LinearConstraintProjection_build   |     +0.41 +/- 2.50     | +4.62e-02 +/- 2.81e-01 |  1.13e+01 +/- 2.3e-01  |  1.12e+01 +/- 1.6e-01  |
| test_objective_compute_ripple_spline    |     +0.44 +/- 1.09     | +1.36e-03 +/- 3.36e-03 |  3.10e-01 +/- 1.4e-03  |  3.09e-01 +/- 3.0e-03  |
| test_objective_grad_ripple_spline       |     +0.18 +/- 2.17     | +2.92e-03 +/- 3.45e-02 |  1.59e+00 +/- 2.4e-02  |  1.59e+00 +/- 2.4e-02  |

sinaatalay force-pushed the dp/stochastic-coil branch 2 times, most recently from 64a7842 to bbb2072 on January 20, 2025 18:27
@@ -86,7 +86,10 @@ def _unjittable(x):
     if isinstance(x, dict):
         return all([_unjittable(y) or y is None for y in x.values()])
     if hasattr(x, "dtype") and np.ndim(x) == 0:
-        return np.issubdtype(x.dtype, np.bool_) or np.issubdtype(x.dtype, np.int_)
+        try:
+            return np.issubdtype(x.dtype, np.bool_) or np.issubdtype(x.dtype, np.int_)
Collaborator Author

This change was, if I recall correctly, to handle setting the key in the stochastic dict to be static. I don't remember the exact issue, but I think that whatever type of object the JAX key is may not have a .dtype attribute.
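
For reference, a minimal sketch of the kind of guard this change seems to be after, assuming the problem really is an object whose dtype is missing or not understood by NumPy. `looks_static` is a hypothetical helper for illustration, not the actual `_unjittable` code in this PR:

```python
import numpy as np


def looks_static(x):
    """Return True if x looks like a bool/int scalar flag (hypothetical helper)."""
    try:
        # Objects with no usable dtype raise here and are handled below.
        is_flag_dtype = np.issubdtype(x.dtype, np.bool_) or np.issubdtype(
            x.dtype, np.integer
        )
        return bool(is_flag_dtype) and np.ndim(x) == 0
    except (AttributeError, TypeError):
        # Fall back to plain Python types when dtype is missing or unsupported.
        return isinstance(x, (bool, int))
```

With this shape, an object that lacks `.dtype` falls through to the plain `isinstance` check rather than raising.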

Collaborator Author

dpanici commented Jan 22, 2025

unscented transform as an option?
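
To make the idea concrete: the unscented transform would replace random draws with a small deterministic set of sigma points that reproduce the mean and covariance exactly. A rough NumPy sketch, assuming the basic form with a single scaling parameter `kappa`; nothing here is part of this PR, and `sigma_points` is a hypothetical name:

```python
import numpy as np


def sigma_points(mean, cov, kappa=1.0):
    """Return the 2n+1 sigma points and weights for N(mean, cov)."""
    n = mean.size
    # Columns of L are the scaled square-root directions of the covariance.
    L = np.linalg.cholesky((n + kappa) * cov)
    points = np.vstack([mean, mean + L.T, mean - L.T])
    weights = np.full(2 * n + 1, 1.0 / (2.0 * (n + kappa)))
    weights[0] = kappa / (n + kappa)
    return points, weights


mean = np.array([1.0, 2.0])
cov = np.array([[2.0, 0.5], [0.5, 1.0]])
points, weights = sigma_points(mean, cov)

# The weighted sample statistics match the inputs by construction.
recovered_mean = weights @ points
centered = points - mean
recovered_cov = (weights[:, None] * centered).T @ centered
```

The number of points grows linearly with the dimension of the perturbation, so whether this is cheaper than sampling depends on how many degrees of freedom the coil perturbation has.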

@sinaatalay
Member

Is the perturbation the same at each step of the optimization? It would be better if it were.
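
A minimal sketch of what "the same at each step" could look like, assuming the perturbations are drawn once up front and then reused for every objective evaluation, so the averaged objective is deterministic. The names `make_averaged_objective` and `toy_objective` are hypothetical and this is not the PR's code:

```python
import numpy as np


def make_averaged_objective(objective, n_samples, sigma, dim, seed=0):
    """Average `objective` over a fixed set of perturbations drawn once."""
    rng = np.random.default_rng(seed)
    # Drawn a single time here, then closed over by `averaged`.
    perturbations = sigma * rng.standard_normal((n_samples, dim))

    def averaged(x):
        return np.mean([objective(x + p) for p in perturbations])

    return averaged


def toy_objective(x):
    return np.sum(x**2)


f = make_averaged_objective(toy_objective, n_samples=8, sigma=0.01, dim=3)
x = np.ones(3)
first, second = f(x), f(x)  # identical, because the samples are fixed
```

Because the samples do not change between calls, a deterministic optimizer sees a smooth, repeatable objective.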

Member

sinaatalay commented Feb 3, 2025

Perturbations seem to be working:

[Images: perturbation plots for the X, Y, and Z components]
  • The blue line represents the position perturbation as a function of $\theta$ (ranging from $0$ to $2\pi$, covering a complete revolution of the coil).
  • The orange and green lines show the derivatives of the position perturbation with respect to $\theta$, which will be used to perturb the tangent at each point.
    • The orange line corresponds to samples drawn from a normal distribution using the covariance matrix (which is what we are going to use for the optimization).
    • The green line is computed numerically through finite differences, for validation.

The perturbation is periodic (the position perturbations at $\theta = 0$ and $\theta = 2\pi$ are the same), a crucial aspect since the coil’s position must be identical at $\theta = 0$ and $\theta = 2\pi$.
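
The check described above can be sketched in plain NumPy. This is a sketch under stated assumptions, not the PR's implementation: it assumes a squared-exponential kernel made periodic by summing shifted copies, and the names `periodic_kernel_blocks` and `sample_perturbation`, the grid size, and the jitter value are all illustrative. The block layout follows the `[[pp, -dp], [dp, -dd]]` structure of the covariance matrix in the diff.

```python
import numpy as np


def periodic_kernel_blocks(theta, sigma=1.0, ell=1.0, n_images=3):
    """Return k, k', k'' evaluated at tau = theta_i - theta_j."""
    tau = theta[:, None] - theta[None, :]
    k = np.zeros_like(tau)
    dk = np.zeros_like(tau)
    d2k = np.zeros_like(tau)
    # Summing shifted copies makes the kernel 2*pi periodic.
    for j in range(-n_images, n_images + 1):
        u = tau + 2.0 * np.pi * j
        g = sigma**2 * np.exp(-(u**2) / (2.0 * ell**2))
        k += g
        dk += -(u / ell**2) * g
        d2k += (u**2 / ell**4 - 1.0 / ell**2) * g
    return k, dk, d2k


def sample_perturbation(theta, rng, jitter=1e-8, **kernel_kwargs):
    """Jointly sample a perturbation f and its derivative df/dtheta."""
    k, dk, d2k = periodic_kernel_blocks(theta, **kernel_kwargs)
    # Joint covariance of (f, f') for a stationary kernel.
    cov = np.block([[k, -dk], [dk, -d2k]])
    # A small diagonal term keeps the Cholesky factorization stable.
    chol = np.linalg.cholesky(cov + jitter * np.eye(cov.shape[0]))
    sample = chol @ rng.standard_normal(cov.shape[0])
    n = theta.size
    return sample[:n], sample[n:]


theta = np.linspace(0.0, 2.0 * np.pi, 128, endpoint=False)
f, df = sample_perturbation(theta, np.random.default_rng(0))

# Validate the sampled derivative with periodic central differences.
h = theta[1] - theta[0]
df_fd = (np.roll(f, -1) - np.roll(f, 1)) / (2.0 * h)
max_err = np.max(np.abs(df - df_fd))
```

Sampling the derivative jointly, rather than differentiating the sampled positions, gives tangent perturbations that are consistent with the position perturbations without introducing discretization error.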

Here is an example of a coil and its perturbed version (a large standard deviation is used for illustration):

image

Here is an optimization example (the black one is the result of the regular optimization, the red one is the result of the stochastic optimization):

image

Here is what the API will look like:

QuadraticFlux(
    eq,
    field=coilset,
    eval_grid=plasma_grid,
    field_grid=coil_grid,
    vacuum=True,
    weight=weights_dict["quadratic flux"],
    stochastic_optimization_settings={
        "number_of_samples": 1,
        "length_scale": 0.2,
        "standard_deviation": 0.01,
    },
)
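
One way the settings dict could be validated before use is to convert it into a small frozen dataclass. This is only a sketch: the three field names come from the example above, while the class name `StochasticSettingsSketch`, the `from_dict` helper, and the validation rules are assumptions made for illustration.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class StochasticSettingsSketch:
    """Hypothetical container for the stochastic optimization settings."""

    number_of_samples: int = 1
    length_scale: float = 0.2
    standard_deviation: float = 0.01

    def __post_init__(self):
        if self.number_of_samples < 1:
            raise ValueError("number_of_samples must be at least 1")
        if self.length_scale <= 0 or self.standard_deviation < 0:
            raise ValueError("length_scale must be > 0 and standard_deviation >= 0")

    @classmethod
    def from_dict(cls, settings):
        """Build the settings from a dict, rejecting unknown keys."""
        known = {field.name for field in dataclasses.fields(cls)}
        unknown = set(settings) - known
        if unknown:
            raise KeyError(f"unknown settings: {sorted(unknown)}")
        return cls(**settings)


settings = StochasticSettingsSketch.from_dict(
    {"number_of_samples": 4, "length_scale": 0.2, "standard_deviation": 0.01}
)
```

Rejecting unknown keys early means a misspelled key fails loudly instead of silently falling back to a default.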

Some questions:

  1. Do we want to add a stochastic_optimization_settings key to all the coil objectives?

@@ -1044,7 +1077,7 @@ def _compute_A_or_B(
         return AB

     def compute_magnetic_field(
-        self, coords, params=None, basis="rpz", source_grid=None, transforms=None
+        self, coords, params=None, basis="rpz", source_grid=None, transforms=None, perturbations=None
Contributor

[blackfmt] reported by reviewdog 🐶

Suggested change
-        self, coords, params=None, basis="rpz", source_grid=None, transforms=None, perturbations=None
+        self,
+        coords,
+        params=None,
+        basis="rpz",
+        source_grid=None,
+        transforms=None,
+        perturbations=None,

@@ -1077,7 +1110,7 @@
         is approximately quadratic in the number of coil points.

         """
-        return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B")
+        return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B", perturbations)
Contributor

[blackfmt] reported by reviewdog 🐶

Suggested change
-        return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B", perturbations)
+        return self._compute_A_or_B(
+            coords, params, basis, source_grid, transforms, "B", perturbations
+        )

Comment on lines +1335 to +1338
original_covariance_matrix = jnp.block([
    [cof_f_pp(XX, YY), -cov_f_dp(XX, YY)],
    [cov_f_dp(XX, YY), -cov_f_dd(XX, YY)],
])
Contributor

[blackfmt] reported by reviewdog 🐶

Suggested change
-original_covariance_matrix = jnp.block([
-    [cof_f_pp(XX, YY), -cov_f_dp(XX, YY)],
-    [cov_f_dp(XX, YY), -cov_f_dd(XX, YY)],
-])
+original_covariance_matrix = jnp.block(
+    [
+        [cof_f_pp(XX, YY), -cov_f_dp(XX, YY)],
+        [cov_f_dp(XX, YY), -cov_f_dd(XX, YY)],
+    ]
+)

Comment on lines +2291 to +2294
coil_currents = jnp.concatenate([
    jnp.atleast_1d(param[idx])
    for param, idx in zip(tree_leaves(coil_params), self._indices)
])
Contributor

[blackfmt] reported by reviewdog 🐶

Suggested change
-coil_currents = jnp.concatenate([
-    jnp.atleast_1d(param[idx])
-    for param, idx in zip(tree_leaves(coil_params), self._indices)
-])
+coil_currents = jnp.concatenate(
+    [
+        jnp.atleast_1d(param[idx])
+        for param, idx in zip(tree_leaves(coil_params), self._indices)
+    ]
+)

Comment on lines +2571 to +2574
self._normalization = np.max([
    np.mean(np.abs(Phi)),
    1,
])
Contributor

[blackfmt] reported by reviewdog 🐶

Suggested change
-self._normalization = np.max([
-    np.mean(np.abs(Phi)),
-    1,
-])
+self._normalization = np.max(
+    [
+        np.mean(np.abs(Phi)),
+        1,
+    ]
+)

@@ -1044,7 +1077,7 @@ def _compute_A_or_B(
         return AB

     def compute_magnetic_field(
-        self, coords, params=None, basis="rpz", source_grid=None, transforms=None
+        self, coords, params=None, basis="rpz", source_grid=None, transforms=None, perturbations=None
Contributor

🚫 [flake8] <501> reported by reviewdog 🐶
line too long (101 > 88 characters)

@@ -1077,7 +1110,7 @@
         is approximately quadratic in the number of coil points.

         """
-        return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B")
+        return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B", perturbations)
Contributor

🚫 [flake8] <501> reported by reviewdog 🐶
line too long (103 > 88 characters)

from desc.backend import jnp, tree_flatten, tree_leaves, tree_map, tree_unflatten
import sympy as sp

from desc.backend import (
Contributor

[flake8] <401> reported by reviewdog 🐶
'desc.backend.fori_loop' imported but unused

@@ -1225,6 +1236,150 @@
         return out * constants["mask"]


+@dataclasses.dataclass
+class StochasticOptimizationSettings:
+    """See https://doi.org/10.1088/1741-4326/ac45f3 for implementation details
Contributor

[flake8] <205> reported by reviewdog 🐶
1 blank line required between summary line and description

@@ -1225,6 +1236,150 @@
         return out * constants["mask"]


+@dataclasses.dataclass
+class StochasticOptimizationSettings:
+    """See https://doi.org/10.1088/1741-4326/ac45f3 for implementation details
Contributor

[flake8] <400> reported by reviewdog 🐶
First line should end with a period
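
For illustration, a docstring shape that would likely satisfy both checks reported above: a one-line summary ending in a period, then a blank line, then the description. The summary wording, the class name `StochasticOptimizationSettingsExample`, and the single field are assumptions, not the PR's final text:

```python
import dataclasses


@dataclasses.dataclass
class StochasticOptimizationSettingsExample:
    """Settings for stochastic coil optimization.

    See https://doi.org/10.1088/1741-4326/ac45f3 for implementation details.
    """

    number_of_samples: int = 1


doc_lines = StochasticOptimizationSettingsExample.__doc__.splitlines()
```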

default_factory=lambda: jnp.array([])
)

def compute_covariance_matrix(self) -> jnp.ndarray:
Contributor

[flake8] <102> reported by reviewdog 🐶
Missing docstring in public method

)

# Construct 2n x 2n covariance matrix:
# K = [[cov_f_pp, cov_f_pd], [cov_f_dp, cov_f_dd]]
Contributor

🚫 [flake8] <800> reported by reviewdog 🐶
Found commented out code

return original_covariance_matrix + small_diagonal_matrix

@functools.cached_property
def perturbations(self) -> jnp.ndarray:
Contributor

[flake8] <102> reported by reviewdog 🐶
Missing docstring in public method
