Skip to content

Commit

Permalink
Merge branch 'master' into rc/kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici authored Nov 27, 2024
2 parents 34b4e57 + 2741269 commit 69574a7
Show file tree
Hide file tree
Showing 26 changed files with 200,845 additions and 79 deletions.
14 changes: 13 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,25 @@ Changelog
=========

New Features
- Add ``from_input_file`` method to ``Equilibrium`` class to generate an ``Equilibrium`` object with boundary, profiles, resolution and flux specified in a given DESC or VMEC input file
- Adds ``from_input_file`` method to ``Equilibrium`` class to generate an ``Equilibrium`` object with boundary, profiles, resolution and flux specified in a given DESC or VMEC input file
- Adds function ``solve_regularized_surface_current`` to ``desc.magnetic_fields`` module that implements the REGCOIL algorithm (Landreman, (2017)) for surface current normal field optimization
* Can specify the tuple ``current_helicity=(M_coil, N_coil)`` to determine if resulting contours correspond to helical topology (both ``(M_coil, N_coil)`` not equal to 0), modular (``N_coil`` equal to 0 and ``M_coil`` nonzero) or windowpane/saddle (``M_coil`` and ``N_coil`` both zero)
* ``M_coil`` is the number of poloidal transits a coil makes before returning to itself, while ``N_coil`` is the number of toroidal transits a coil makes before returning to itself (this is sort of like the QS ``helicity``)
* if multiple values of the regularization parameter are input, will return a family of surface current fields (as a list) corresponding to the solution at each regularization value
- Adds method ``to_CoilSet`` to ``FourierCurrentPotentialField`` which implements a coil cutting algorithm to discretize the surface current into coils
* works for both modular and helical coils
- Adds a new objective ``SurfaceCurrentRegularization`` (which minimizes ``w*|K|``, the regularization term from surface current in the REGCOIL algorithm, with `w` being the objective weight which act as the regularization parameter)
* use of both this and the ``QuadraticFlux`` objective allows for REGCOIL solutions to be obtained through the optimization framework, and combined with other objectives as well.
- Changes local area weighting of Bn in QuadraticFlux objective to be the square root of the local area element (Note that any existing optimizations using this objective may need different weights to achieve the same result now.)
- Adds a new tutorial showing how to use``REGCOIL`` features.


Bug Fixes

- Fixes bug that occurs when taking the gradient of ``root`` and ``root_scalar`` with newer versions of JAX (>=0.4.34) and unpins the JAX version
- Changes ``FixLambdaGauge`` constraint to now enforce zero flux surface average for lambda, instead of enforcing lambda(rho,0,0)=0 as it was incorrectly doing before.
- Fixes bug in ``softmin/softmax`` implementation.


v0.12.3
-------
Expand Down
4 changes: 2 additions & 2 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
from jax.numpy.fft import irfft, rfft, rfft2
from jax.scipy.fft import dct, idct
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
from jax.scipy.special import gammaln, logsumexp
from jax.scipy.special import gammaln
from jax.tree_util import (
register_pytree_node,
tree_flatten,
Expand Down Expand Up @@ -445,7 +445,7 @@ def tangent_solve(g, y):
qr,
solve_triangular,
)
from scipy.special import gammaln, logsumexp # noqa: F401
from scipy.special import gammaln # noqa: F401
from scipy.special import softmax as softargmax # noqa: F401

trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
Expand Down
70 changes: 57 additions & 13 deletions desc/compute/_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ def _Phi_z_FourierCurrentPotentialField(params, transforms, profiles, data, **kw
units_long="Amperes",
description="Surface current potential",
dim=1,
params=["params"],
transforms={"grid": [], "potential": []},
params=[],
transforms={"grid": [], "potential": [], "params": []},
profiles=[],
coordinates="tz",
data=[],
Expand All @@ -393,7 +393,7 @@ def _Phi_CurrentPotentialField(params, transforms, profiles, data, **kwargs):
data["Phi"] = transforms["potential"](
transforms["grid"].nodes[:, 1],
transforms["grid"].nodes[:, 2],
**params["params"]
**transforms["params"]
)
return data

Expand All @@ -405,8 +405,8 @@ def _Phi_CurrentPotentialField(params, transforms, profiles, data, **kwargs):
units_long="Amperes",
description="Surface current potential, poloidal derivative",
dim=1,
params=["params"],
transforms={"grid": [], "potential_dtheta": []},
params=[],
transforms={"grid": [], "potential_dtheta": [], "params": []},
profiles=[],
coordinates="tz",
data=[],
Expand All @@ -416,7 +416,7 @@ def _Phi_t_CurrentPotentialField(params, transforms, profiles, data, **kwargs):
data["Phi_t"] = transforms["potential_dtheta"](
transforms["grid"].nodes[:, 1],
transforms["grid"].nodes[:, 2],
**params["params"]
**transforms["params"]
)
return data

Expand All @@ -428,8 +428,8 @@ def _Phi_t_CurrentPotentialField(params, transforms, profiles, data, **kwargs):
units_long="Amperes",
description="Surface current potential, toroidal derivative",
dim=1,
params=["params"],
transforms={"grid": [], "potential_dzeta": []},
params=[],
transforms={"grid": [], "potential_dzeta": [], "params": []},
profiles=[],
coordinates="tz",
data=[],
Expand All @@ -439,11 +439,55 @@ def _Phi_z_CurrentPotentialField(params, transforms, profiles, data, **kwargs):
data["Phi_z"] = transforms["potential_dzeta"](
transforms["grid"].nodes[:, 1],
transforms["grid"].nodes[:, 2],
**params["params"]
**transforms["params"]
)
return data


@register_compute_fun(
name="K^theta",
label="K^{\\theta}",
units="A/m^2",
units_long="Amperes per square meter",
description="Contravariant poloidal component of surface current density",
dim=1,
params=[],
transforms={},
profiles=[],
coordinates="tz",
data=["Phi_z", "|e_theta x e_zeta|"],
parameterization=[
"desc.magnetic_fields._current_potential.CurrentPotentialField",
"desc.magnetic_fields._current_potential.FourierCurrentPotentialField",
],
)
def _K_sup_theta_CurrentPotentialField(params, transforms, profiles, data, **kwargs):
data["K^theta"] = -data["Phi_z"] * (1 / data["|e_theta x e_zeta|"])
return data


@register_compute_fun(
name="K^zeta",
label="K^{\\zeta}",
units="A/m^2",
units_long="Amperes per square meter",
description="Contravariant toroidal component of surface current density",
dim=1,
params=[],
transforms={},
profiles=[],
coordinates="tz",
data=["Phi_t", "|e_theta x e_zeta|"],
parameterization=[
"desc.magnetic_fields._current_potential.CurrentPotentialField",
"desc.magnetic_fields._current_potential.FourierCurrentPotentialField",
],
)
def _K_sup_zeta_CurrentPotentialField(params, transforms, profiles, data, **kwargs):
data["K^zeta"] = data["Phi_t"] * (1 / data["|e_theta x e_zeta|"])
return data


@register_compute_fun(
name="K",
label="\\mathbf{K}",
Expand All @@ -456,16 +500,16 @@ def _Phi_z_CurrentPotentialField(params, transforms, profiles, data, **kwargs):
transforms={},
profiles=[],
coordinates="tz",
data=["Phi_t", "Phi_z", "e_theta", "e_zeta", "|e_theta x e_zeta|"],
data=["K^theta", "K^zeta", "e_theta", "e_zeta"],
parameterization=[
"desc.magnetic_fields._current_potential.CurrentPotentialField",
"desc.magnetic_fields._current_potential.FourierCurrentPotentialField",
],
)
def _K_CurrentPotentialField(params, transforms, profiles, data, **kwargs):
data["K"] = (
data["Phi_t"] * (1 / data["|e_theta x e_zeta|"]) * data["e_zeta"].T
).T - (data["Phi_z"] * (1 / data["|e_theta x e_zeta|"]) * data["e_theta"].T).T
data["K"] = (data["K^zeta"] * data["e_zeta"].T).T + (
data["K^theta"] * data["e_theta"].T
).T
return data


Expand Down
6 changes: 5 additions & 1 deletion desc/magnetic_fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,9 @@
field_line_integrate,
read_BNORM_file,
)
from ._current_potential import CurrentPotentialField, FourierCurrentPotentialField
from ._current_potential import (
CurrentPotentialField,
FourierCurrentPotentialField,
solve_regularized_surface_current,
)
from ._dommaschk import DommaschkPotentialField, dommaschk_potential
6 changes: 4 additions & 2 deletions desc/magnetic_fields/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,11 @@ def compute_Bnormal(
Returns
-------
Bnorm : ndarray
The normal magnetic field to the surface given, of size grid.num_nodes.
The normal magnetic field to the surface given, as an array of
size ``grid.num_nodes``.
coords: ndarray
the locations (in specified basis) at which the Bnormal was calculated
the locations (in specified basis) at which the Bnormal was calculated,
given as a ``(grid.num_nodes , 3)`` shaped array.
"""
calc_Bplasma = False
Expand Down
Loading

0 comments on commit 69574a7

Please sign in to comment.