Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize omnigenity objective over multiple surfaces #1225

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ Changelog
New Features
- Add ``from_input_file`` method to ``Equilibrium`` class to generate an ``Equilibrium`` object with boundary, profiles, resolution and flux specified in a given DESC or VMEC input file

Minor Changes

- ``desc.objectives.Omnigenity`` is now vectorized and able to optimize multiple surfaces at the same time. Previously it was required to use a different objective for each surface.

Bug Fixes

Expand Down
73 changes: 57 additions & 16 deletions desc/compute/_omnigenity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
expensive computations.
"""

import functools

from interpax import interp1d

from desc.backend import jnp, sign, vmap
Expand Down Expand Up @@ -509,7 +511,7 @@ def _omni_angle(params, transforms, profiles, data, **kwargs):
description="Boozer poloidal angle",
dim=1,
params=[],
transforms={},
transforms={"grid": []},
profiles=[],
coordinates="rtz",
data=["alpha", "h"],
Expand All @@ -519,9 +521,36 @@ def _omni_angle(params, transforms, profiles, data, **kwargs):
)
def _omni_map_theta_B(params, transforms, profiles, data, **kwargs):
M, N = kwargs.get("helicity", (1, 0))
iota = kwargs.get("iota", 1)
iota = kwargs.get("iota", jnp.ones(transforms["grid"].num_rho))

theta_B, zeta_B = _omnigenity_mapping(
M, N, iota, data["alpha"], data["h"], transforms["grid"]
)
data["theta_B"] = theta_B
data["zeta_B"] = zeta_B
return data


# coordinate mapping matrix from (alpha,h) to (theta_B,zeta_B)
def _omnigenity_mapping(M, N, iota, alpha, h, grid):
iota = jnp.atleast_1d(iota)
assert (
len(iota) == grid.num_rho
), f"got ({len(iota)}) iota values for grid with {grid.num_rho} surfaces"
matrix = jnp.atleast_3d(_omnigenity_mapping_matrix(M, N, iota))
# solve for (theta_B,zeta_B) corresponding to (eta,alpha)
alpha = grid.meshgrid_reshape(alpha, "trz")
h = grid.meshgrid_reshape(h, "trz")
coords = jnp.stack((alpha, h))
# matrix has shape (nr,2,2), coords is shape (2, nt, nr, nz)
# we vectorize the matmul over rho
booz = jnp.einsum("rij,jtrz->itrz", matrix, coords)
theta_B = booz[0].flatten(order="F")
zeta_B = booz[1].flatten(order="F")
return theta_B, zeta_B


@functools.partial(jnp.vectorize, signature="(),(),()->(2,2)")
def _omnigenity_mapping_matrix(M, N, iota):
# need a bunch of wheres to avoid division by zero causing NaN in backward pass
# this is fine since the incorrect values get ignored later, except in OT or OH
# where fieldlines are exactly parallel to |B| contours, but this is a degenerate
Expand All @@ -541,12 +570,7 @@ def _omni_map_theta_B(params, transforms, profiles, data, **kwargs):
mat_OH,
),
)

# solve for (theta_B,zeta_B) corresponding to (eta,alpha)
booz = matrix @ jnp.vstack((data["alpha"], data["h"]))
data["theta_B"] = booz[0, :]
data["zeta_B"] = booz[1, :]
return data
return matrix


@register_compute_fun(
Expand Down Expand Up @@ -575,7 +599,7 @@ def _omni_map_zeta_B(params, transforms, profiles, data, **kwargs):
description="Magnitude of omnigenous magnetic field",
dim=1,
params=["B_lm"],
transforms={"B": [[0, 0, 0]]},
transforms={"grid": [], "B": [[0, 0, 0]]},
profiles=[],
coordinates="rtz",
data=["eta"],
Expand All @@ -584,15 +608,32 @@ def _omni_map_zeta_B(params, transforms, profiles, data, **kwargs):
def _B_omni(params, transforms, profiles, data, **kwargs):
# reshaped to size (L_B, M_B)
B_lm = params["B_lm"].reshape((transforms["B"].basis.L + 1, -1))
# assuming single flux surface, so only take first row (single node)
B_input = vmap(lambda x: transforms["B"].transform(x))(B_lm.T)[:, 0]
B_input = jnp.sort(B_input) # sort to ensure monotonicity
eta_input = jnp.linspace(0, jnp.pi / 2, num=B_input.size)

def _transform(x):
y = transforms["B"].transform(x)
return transforms["grid"].compress(y)

B_input = vmap(_transform)(B_lm.T)
# B_input has shape (num_knots, num_rho)
B_input = jnp.sort(B_input, axis=0) # sort to ensure monotonicity
eta_input = jnp.linspace(0, jnp.pi / 2, num=B_input.shape[0])
eta = transforms["grid"].meshgrid_reshape(data["eta"], "rtz")
eta = eta.reshape((transforms["grid"].num_rho, -1))

def _interp(x, B):
return interp1d(x, eta_input, B, method="monotonic-0")

# |B|_omnigeneous is an even function so B(-eta) = B(+eta) = B(|eta|)
data["|B|"] = interp1d(
jnp.abs(data["eta"]), eta_input, B_input, method="monotonic-0"
B = vmap(_interp)(jnp.abs(eta), B_input.T) # shape (nr, nt*nz)
B = B.reshape(
(
transforms["grid"].num_rho,
transforms["grid"].num_poloidal,
transforms["grid"].num_zeta,
)
)
B = jnp.moveaxis(B, 0, 1)
data["|B|"] = B.flatten(order="F")
return data


Expand Down
88 changes: 48 additions & 40 deletions desc/objectives/_omnigenity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import warnings

from desc.backend import jnp
from desc.backend import jnp, vmap
from desc.compute import get_profiles, get_transforms
from desc.compute._omnigenity import _omnigenity_mapping
from desc.compute.utils import _compute as compute_fun
from desc.grid import LinearGrid
from desc.utils import Timer, errorif, warnif
Expand Down Expand Up @@ -515,13 +516,12 @@
eq_grid : Grid, optional
Collocation grid containing the nodes to evaluate at for equilibrium data.
Defaults to a linearly space grid on the rho=1 surface.
Must be a single flux surface without stellarator symmetry.
Must be without stellarator symmetry.
field_grid : Grid, optional
Collocation grid containing the nodes to evaluate at for omnigenous field data.
The grid nodes are given in the usual (ρ,θ,ζ) coordinates (with θ ∈ [0, 2π),
ζ ∈ [0, 2π/NFP)), but θ is mapped to η and ζ is mapped to α. Defaults to a
linearly space grid on the rho=1 surface. Must be a single flux surface without
stellarator symmetry.
linearly space grid on the rho=1 surface. Must be without stellarator symmetry.
M_booz : int, optional
Poloidal resolution of Boozer transformation. Default = 2 * eq.M.
N_booz : int, optional
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is probably discussed during the first PR but can we change the docs of eq_fixed at line 536 to say "only the field is allowed to change" or something on that front? The user probably doesn't know what self.things mean

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for line 542

Expand Down Expand Up @@ -633,9 +633,9 @@

# default grids
if self._eq_grid is None and self._field_grid is not None:
rho = self._field_grid.nodes[0, 0]
rho = self._field_grid.nodes[self._field_grid.unique_rho_idx, 0]

Check warning on line 636 in desc/objectives/_omnigenity.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/_omnigenity.py#L636

Added line #L636 was not covered by tests
elif self._eq_grid is not None and self._field_grid is None:
rho = self._eq_grid.nodes[0, 0]
rho = self._eq_grid.nodes[self._eq_grid.unique_rho_idx, 0]
elif self._eq_grid is None and self._field_grid is None:
rho = 1.0
if self._eq_grid is None:
Expand All @@ -661,12 +661,13 @@
)
errorif(eq_grid.sym, msg="eq_grid must not be symmetric")
errorif(field_grid.sym, msg="field_grid must not be symmetric")
errorif(eq_grid.num_rho != 1, msg="eq_grid must be a single surface")
errorif(field_grid.num_rho != 1, msg="field_grid must be a single surface")
field_rho = field_grid.nodes[field_grid.unique_rho_idx, 0]
eq_rho = eq_grid.nodes[eq_grid.unique_rho_idx, 0]
errorif(
eq_grid.nodes[eq_grid.unique_rho_idx, 0]
!= field_grid.nodes[field_grid.unique_rho_idx, 0],
msg="eq_grid and field_grid must be the same surface",
any(eq_rho != field_rho),
msg="eq_grid and field_grid must be the same surface(s), "
+ f"eq_grid has surfaces {eq_rho}, "
+ f"field_grid has surfaces {field_rho}",
)
errorif(
jnp.any(field.B_lm[: field.M_B] < 0),
Expand Down Expand Up @@ -772,6 +773,9 @@
eq_params = params_1
field_params = params_2

eq_grid = constants["eq_transforms"]["grid"]
field_grid = constants["field_transforms"]["grid"]

# compute eq data
if self._eq_fixed:
eq_data = constants["eq_data"]
Expand All @@ -789,27 +793,15 @@
field_data = constants["field_data"]
# update theta_B and zeta_B with new iota from the equilibrium
M, N = constants["helicity"]
iota = jnp.mean(eq_data["iota"])
# see comment in desc.compute._omnigenity for the explanation of these
# wheres
mat_OP = jnp.array(
[[N, iota / jnp.where(N == 0, 1, N)], [0, 1 / jnp.where(N == 0, 1, N)]]
)
mat_OT = jnp.array([[0, -1], [M, -1 / jnp.where(iota == 0, 1.0, iota)]])
den = jnp.where((N - M * iota) == 0, 1.0, (N - M * iota))
mat_OH = jnp.array([[N, M * iota / den], [M, M / den]])
matrix = jnp.where(
M == 0,
mat_OP,
jnp.where(
N == 0,
mat_OT,
mat_OH,
),
iota = eq_data["iota"][eq_grid.unique_rho_idx]
theta_B, zeta_B = _omnigenity_mapping(
M,
N,
iota,
field_data["alpha"],
field_data["h"],
field_grid,
)
booz = matrix @ jnp.vstack((field_data["alpha"], field_data["h"]))
theta_B = booz[0, :]
zeta_B = booz[1, :]
else:
field_data = compute_fun(
"desc.magnetic_fields._core.OmnigenousField",
Expand All @@ -818,22 +810,38 @@
transforms=constants["field_transforms"],
profiles={},
helicity=constants["helicity"],
iota=jnp.mean(eq_data["iota"]),
iota=eq_data["iota"][eq_grid.unique_rho_idx],
)
theta_B = field_data["theta_B"]
zeta_B = field_data["zeta_B"]

# additional computations that cannot be part of the regular compute API
nodes = jnp.vstack(
(
jnp.zeros_like(theta_B),
theta_B,
zeta_B,

def _compute_B_eta_alpha(theta_B, zeta_B, B_mn):
nodes = jnp.vstack(
(
jnp.zeros_like(theta_B),
theta_B,
zeta_B,
)
).T
B_eta_alpha = jnp.matmul(
constants["eq_transforms"]["B"].basis.evaluate(nodes), B_mn
)
).T
B_eta_alpha = jnp.matmul(
constants["eq_transforms"]["B"].basis.evaluate(nodes), eq_data["|B|_mn"]
return B_eta_alpha

theta_B = field_grid.meshgrid_reshape(theta_B, "rtz").reshape(
(field_grid.num_rho, -1)
)
zeta_B = field_grid.meshgrid_reshape(zeta_B, "rtz").reshape(
(field_grid.num_rho, -1)
)
B_mn = eq_data["|B|_mn"].reshape((eq_grid.num_rho, -1))
B_eta_alpha = vmap(_compute_B_eta_alpha)(theta_B, zeta_B, B_mn)
B_eta_alpha = B_eta_alpha.reshape(
(field_grid.num_rho, field_grid.num_theta, field_grid.num_zeta)
)
B_eta_alpha = jnp.moveaxis(B_eta_alpha, 0, 1).flatten(order="F")
omnigenity_error = B_eta_alpha - field_data["|B|"]
weights = (self.eta_weight + 1) / 2 + (self.eta_weight - 1) / 2 * jnp.cos(
field_data["eta"]
Expand Down
52 changes: 52 additions & 0 deletions tests/test_objective_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,6 +1348,58 @@ def test_signed_plasma_vessel_distance(self):
)
obj.build()

@pytest.mark.unit
def test_omnigenity_multiple_surfaces(self):
"""Test omnigenity transform vectorized over multiple surfaces."""
surf = FourierRZToroidalSurface.from_qp_model(
major_radius=1,
aspect_ratio=20,
elongation=6,
mirror_ratio=0.2,
torsion=0.1,
NFP=1,
sym=True,
)
eq = Equilibrium(
Psi=6e-3,
M=4,
N=4,
surface=surf,
iota=PowerSeriesProfile(1, 0, -1), # ensure diff surfs have diff iota
)
field = OmnigenousField(
L_B=1,
M_B=3,
L_x=1,
M_x=1,
N_x=1,
NFP=eq.NFP,
helicity=(1, 1),
B_lm=np.array(
[
[0.8, 1.0, 1.2],
[-0.4, 0.0, 0.6], # radially varying B
]
).flatten(),
)
grid1 = LinearGrid(rho=0.5, M=eq.M_grid, N=eq.N_grid)
grid2 = LinearGrid(rho=1.0, M=eq.M_grid, N=eq.N_grid)
grid3 = LinearGrid(rho=np.array([0.5, 1.0]), M=eq.M_grid, N=eq.N_grid)
obj1 = Omnigenity(eq=eq, field=field, eq_grid=grid1)
obj2 = Omnigenity(eq=eq, field=field, eq_grid=grid2)
obj3 = Omnigenity(eq=eq, field=field, eq_grid=grid3)
obj1.build()
obj2.build()
obj3.build()
f1 = obj1.compute(*obj1.xs(eq, field))
f2 = obj2.compute(*obj2.xs(eq, field))
f3 = obj3.compute(*obj3.xs(eq, field))
# the order will be different but the values should be the same so we sort
# before comparing
np.testing.assert_allclose(
np.sort(np.concatenate([f1, f2])), np.sort(f3), atol=1e-14
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is the order different exactly?

)


@pytest.mark.regression
def test_derivative_modes():
Expand Down
Loading