-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Robust Coil Optimization (Wechsung) #1520
base: master
Are you sure you want to change the base?
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_midres | -0.46 +/- 5.83 | -3.04e-03 +/- 3.84e-02 | 6.56e-01 +/- 2.7e-02 | 6.59e-01 +/- 2.8e-02 |
test_build_transform_fft_highres | -0.43 +/- 5.94 | -4.09e-03 +/- 5.66e-02 | 9.48e-01 +/- 5.0e-02 | 9.52e-01 +/- 2.6e-02 |
test_equilibrium_init_lowres | -3.66 +/- 4.19 | -1.58e-01 +/- 1.81e-01 | 4.15e+00 +/- 1.2e-01 | 4.30e+00 +/- 1.3e-01 |
test_objective_compile_atf | -1.80 +/- 1.78 | -1.57e-01 +/- 1.56e-01 | 8.58e+00 +/- 1.1e-01 | 8.74e+00 +/- 1.1e-01 |
test_objective_compute_atf | -8.66 +/- 4.21 | -1.57e-03 +/- 7.64e-04 | 1.66e-02 +/- 3.5e-04 | 1.82e-02 +/- 6.8e-04 |
test_objective_jac_atf | -0.25 +/- 1.85 | -4.99e-03 +/- 3.70e-02 | 2.00e+00 +/- 3.2e-02 | 2.00e+00 +/- 1.9e-02 |
test_perturb_1 | -4.37 +/- 3.93 | -7.13e-01 +/- 6.42e-01 | 1.56e+01 +/- 5.0e-01 | 1.63e+01 +/- 4.0e-01 |
test_proximal_jac_atf | +0.27 +/- 1.53 | +2.19e-02 +/- 1.25e-01 | 8.18e+00 +/- 1.1e-01 | 8.16e+00 +/- 5.2e-02 |
test_proximal_freeb_compute | +1.39 +/- 1.46 | +2.56e-03 +/- 2.68e-03 | 1.87e-01 +/- 2.1e-03 | 1.84e-01 +/- 1.7e-03 |
test_solve_fixed_iter_compiled | -0.62 +/- 1.41 | -1.30e-01 +/- 2.95e-01 | 2.08e+01 +/- 4.8e-02 | 2.10e+01 +/- 2.9e-01 |
test_objective_compute_ripple | -0.92 +/- 2.28 | -5.98e-03 +/- 1.49e-02 | 6.47e-01 +/- 1.4e-02 | 6.53e-01 +/- 5.9e-03 |
test_objective_grad_ripple | -1.25 +/- 1.73 | -3.47e-02 +/- 4.79e-02 | 2.74e+00 +/- 3.7e-02 | 2.78e+00 +/- 3.0e-02 |
test_build_transform_fft_lowres | -2.75 +/- 3.46 | -1.88e-02 +/- 2.36e-02 | 6.64e-01 +/- 1.6e-02 | 6.82e-01 +/- 1.7e-02 |
test_equilibrium_init_medres | -3.02 +/- 3.68 | -1.48e-01 +/- 1.80e-01 | 4.74e+00 +/- 1.7e-01 | 4.89e+00 +/- 7.2e-02 |
test_equilibrium_init_highres | +0.36 +/- 2.86 | +2.06e-02 +/- 1.63e-01 | 5.71e+00 +/- 1.2e-01 | 5.69e+00 +/- 1.1e-01 |
test_objective_compile_dshape_current | +1.88 +/- 5.93 | +8.15e-02 +/- 2.58e-01 | 4.43e+00 +/- 1.7e-01 | 4.34e+00 +/- 1.9e-01 |
test_objective_compute_dshape_current | +1.34 +/- 2.69 | +7.36e-05 +/- 1.48e-04 | 5.57e-03 +/- 8.2e-05 | 5.50e-03 +/- 1.2e-04 |
test_objective_jac_dshape_current | +2.03 +/- 7.84 | +9.15e-04 +/- 3.53e-03 | 4.59e-02 +/- 2.5e-03 | 4.50e-02 +/- 2.5e-03 |
test_perturb_2 | -1.45 +/- 2.22 | -3.19e-01 +/- 4.88e-01 | 2.17e+01 +/- 1.8e-01 | 2.20e+01 +/- 4.5e-01 |
test_proximal_freeb_jac | -1.19 +/- 2.86 | -8.71e-02 +/- 2.09e-01 | 7.21e+00 +/- 6.5e-02 | 7.29e+00 +/- 2.0e-01 |
test_solve_fixed_iter | +1.09 +/- 5.58 | +3.75e-01 +/- 1.92e+00 | 3.49e+01 +/- 1.0e+00 | 3.45e+01 +/- 1.6e+00 |
test_LinearConstraintProjection_build | +0.41 +/- 2.50 | +4.62e-02 +/- 2.81e-01 | 1.13e+01 +/- 2.3e-01 | 1.12e+01 +/- 1.6e-01 |
test_objective_compute_ripple_spline | +0.44 +/- 1.09 | +1.36e-03 +/- 3.36e-03 | 3.10e-01 +/- 1.4e-03 | 3.09e-01 +/- 3.0e-03 |
test_objective_grad_ripple_spline | +0.18 +/- 2.17 | +2.92e-03 +/- 3.45e-02 | 1.59e+00 +/- 2.4e-02 | 1.59e+00 +/- 2.4e-02 | |
64a7842
to
bbb2072
Compare
desc/io/optimizable_io.py
Outdated
@@ -86,7 +86,10 @@ def _unjittable(x): | |||
if isinstance(x, dict): | |||
return all([_unjittable(y) or y is None for y in x.values()]) | |||
if hasattr(x, "dtype") and np.ndim(x) == 0: | |||
return np.issubdtype(x.dtype, np.bool_) or np.issubdtype(x.dtype, np.int_) | |||
try: | |||
return np.issubdtype(x.dtype, np.bool_) or np.issubdtype(x.dtype, np.int_) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this change was, if I recall correctly, to handle setting the key
in stochastic
dict to be static, I dont remember exactly the issue but I think that whatever type of object the JAX key is does not have a .dtype
attribute, maybe
unscented transform as an option? |
Is |
4c18677
to
2625355
Compare
@@ -1044,7 +1077,7 @@ def _compute_A_or_B( | |||
return AB | |||
|
|||
def compute_magnetic_field( | |||
self, coords, params=None, basis="rpz", source_grid=None, transforms=None | |||
self, coords, params=None, basis="rpz", source_grid=None, transforms=None, perturbations=None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[blackfmt] reported by reviewdog 🐶
self, coords, params=None, basis="rpz", source_grid=None, transforms=None, perturbations=None | |
self, | |
coords, | |
params=None, | |
basis="rpz", | |
source_grid=None, | |
transforms=None, | |
perturbations=None, |
@@ -1077,7 +1110,7 @@ | |||
is approximately quadratic in the number of coil points. | |||
|
|||
""" | |||
return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") | |||
return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B", perturbations) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[blackfmt] reported by reviewdog 🐶
return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B", perturbations) | |
return self._compute_A_or_B( | |
coords, params, basis, source_grid, transforms, "B", perturbations | |
) |
original_covariance_matrix = jnp.block([ | ||
[cof_f_pp(XX, YY), -cov_f_dp(XX, YY)], | ||
[cov_f_dp(XX, YY), -cov_f_dd(XX, YY)], | ||
]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[blackfmt] reported by reviewdog 🐶
original_covariance_matrix = jnp.block([ | |
[cof_f_pp(XX, YY), -cov_f_dp(XX, YY)], | |
[cov_f_dp(XX, YY), -cov_f_dd(XX, YY)], | |
]) | |
original_covariance_matrix = jnp.block( | |
[ | |
[cof_f_pp(XX, YY), -cov_f_dp(XX, YY)], | |
[cov_f_dp(XX, YY), -cov_f_dd(XX, YY)], | |
] | |
) |
coil_currents = jnp.concatenate([ | ||
jnp.atleast_1d(param[idx]) | ||
for param, idx in zip(tree_leaves(coil_params), self._indices) | ||
]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[blackfmt] reported by reviewdog 🐶
coil_currents = jnp.concatenate([ | |
jnp.atleast_1d(param[idx]) | |
for param, idx in zip(tree_leaves(coil_params), self._indices) | |
]) | |
coil_currents = jnp.concatenate( | |
[ | |
jnp.atleast_1d(param[idx]) | |
for param, idx in zip(tree_leaves(coil_params), self._indices) | |
] | |
) |
self._normalization = np.max([ | ||
np.mean(np.abs(Phi)), | ||
1, | ||
]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[blackfmt] reported by reviewdog 🐶
self._normalization = np.max([ | |
np.mean(np.abs(Phi)), | |
1, | |
]) | |
self._normalization = np.max( | |
[ | |
np.mean(np.abs(Phi)), | |
1, | |
] | |
) |
@@ -1044,7 +1077,7 @@ def _compute_A_or_B( | |||
return AB | |||
|
|||
def compute_magnetic_field( | |||
self, coords, params=None, basis="rpz", source_grid=None, transforms=None | |||
self, coords, params=None, basis="rpz", source_grid=None, transforms=None, perturbations=None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🚫 [flake8] <501> reported by reviewdog 🐶
line too long (101 > 88 characters)
@@ -1077,7 +1110,7 @@ | |||
is approximately quadratic in the number of coil points. | |||
|
|||
""" | |||
return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") | |||
return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B", perturbations) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🚫 [flake8] <501> reported by reviewdog 🐶
line too long (103 > 88 characters)
from desc.backend import jnp, tree_flatten, tree_leaves, tree_map, tree_unflatten | ||
import sympy as sp | ||
|
||
from desc.backend import ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[flake8] <401> reported by reviewdog 🐶
'desc.backend.fori_loop' imported but unused
@@ -1225,6 +1236,150 @@ | |||
return out * constants["mask"] | |||
|
|||
|
|||
@dataclasses.dataclass | |||
class StochasticOptimizationSettings: | |||
"""See https://doi.org/10.1088/1741-4326/ac45f3 for implementation details |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[flake8] <205> reported by reviewdog 🐶
1 blank line required between summary line and description
@@ -1225,6 +1236,150 @@ | |||
return out * constants["mask"] | |||
|
|||
|
|||
@dataclasses.dataclass | |||
class StochasticOptimizationSettings: | |||
"""See https://doi.org/10.1088/1741-4326/ac45f3 for implementation details |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[flake8] <400> reported by reviewdog 🐶
First line should end with a period
default_factory=lambda: jnp.array([]) | ||
) | ||
|
||
def compute_covariance_matrix(self) -> jnp.ndarray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[flake8] <102> reported by reviewdog 🐶
Missing docstring in public method
) | ||
|
||
# Construct 2n x 2n covariance matrix: | ||
# K = [[cov_f_pp, cov_f_pd], [cov_f_dp, cov_f_dd]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🚫 [flake8] <800> reported by reviewdog 🐶
Found commented out code
return original_covariance_matrix + small_diagonal_matrix | ||
|
||
@functools.cached_property | ||
def perturbations(self) -> jnp.ndarray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[flake8] <102> reported by reviewdog 🐶
Missing docstring in public method
No description provided.