diff --git a/CHANGELOG.md b/CHANGELOG.md index 6743861fb..2a307ff0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ Changelog New Features - Add ``from_input_file`` method to ``Equilibrium`` class to generate an ``Equilibrium`` object with boundary, profiles, resolution and flux specified in a given DESC or VMEC input file +Minor Changes + +- ``desc.objectives.Omnigenity`` is now vectorized and able to optimize multiple surfaces at the same time. Previously it was required to use a different objective for each surface. Bug Fixes diff --git a/desc/compute/_omnigenity.py b/desc/compute/_omnigenity.py index 2355a6fcb..2be5a7836 100644 --- a/desc/compute/_omnigenity.py +++ b/desc/compute/_omnigenity.py @@ -9,6 +9,8 @@ expensive computations. """ +import functools + from interpax import interp1d from desc.backend import jnp, sign, vmap @@ -509,7 +511,7 @@ def _omni_angle(params, transforms, profiles, data, **kwargs): description="Boozer poloidal angle", dim=1, params=[], - transforms={}, + transforms={"grid": []}, profiles=[], coordinates="rtz", data=["alpha", "h"], @@ -519,9 +521,36 @@ def _omni_angle(params, transforms, profiles, data, **kwargs): ) def _omni_map_theta_B(params, transforms, profiles, data, **kwargs): M, N = kwargs.get("helicity", (1, 0)) - iota = kwargs.get("iota", 1) + iota = kwargs.get("iota", jnp.ones(transforms["grid"].num_rho)) + + theta_B, zeta_B = _omnigenity_mapping( + M, N, iota, data["alpha"], data["h"], transforms["grid"] + ) + data["theta_B"] = theta_B + data["zeta_B"] = zeta_B + return data + - # coordinate mapping matrix from (alpha,h) to (theta_B,zeta_B) +def _omnigenity_mapping(M, N, iota, alpha, h, grid): + iota = jnp.atleast_1d(iota) + assert ( + len(iota) == grid.num_rho + ), f"got ({len(iota)}) iota values for grid with {grid.num_rho} surfaces" + matrix = jnp.atleast_3d(_omnigenity_mapping_matrix(M, N, iota)) + # solve for (theta_B,zeta_B) corresponding to (eta,alpha) + alpha = grid.meshgrid_reshape(alpha, "trz") + h = grid.meshgrid_reshape(h, "trz") + coords = jnp.stack((alpha, h)) + # matrix has shape (nr,2,2), coords is shape (2, nt, nr, nz) + # we vectorize the matmul over rho + booz = jnp.einsum("rij,jtrz->itrz", matrix, coords) + theta_B = booz[0].flatten(order="F") + zeta_B = booz[1].flatten(order="F") + return theta_B, zeta_B + + +@functools.partial(jnp.vectorize, signature="(),(),()->(2,2)") +def _omnigenity_mapping_matrix(M, N, iota): # need a bunch of wheres to avoid division by zero causing NaN in backward pass # this is fine since the incorrect values get ignored later, except in OT or OH # where fieldlines are exactly parallel to |B| contours, but this is a degenerate @@ -541,12 +570,7 @@ def _omni_map_theta_B(params, transforms, profiles, data, **kwargs): mat_OH, ), ) - - # solve for (theta_B,zeta_B) corresponding to (eta,alpha) - booz = matrix @ jnp.vstack((data["alpha"], data["h"])) - data["theta_B"] = booz[0, :] - data["zeta_B"] = booz[1, :] - return data + return matrix @register_compute_fun( @@ -575,7 +599,7 @@ def _omni_map_zeta_B(params, transforms, profiles, data, **kwargs): description="Magnitude of omnigenous magnetic field", dim=1, params=["B_lm"], - transforms={"B": [[0, 0, 0]]}, + transforms={"grid": [], "B": [[0, 0, 0]]}, profiles=[], coordinates="rtz", data=["eta"], @@ -584,15 +608,32 @@ def _omni_map_zeta_B(params, transforms, profiles, data, **kwargs): def _B_omni(params, transforms, profiles, data, **kwargs): # reshaped to size (L_B, M_B) B_lm = params["B_lm"].reshape((transforms["B"].basis.L + 1, -1)) - # assuming single flux surface, so only take first row (single node) - B_input = vmap(lambda x: transforms["B"].transform(x))(B_lm.T)[:, 0] - B_input = jnp.sort(B_input) # sort to ensure monotonicity - eta_input = jnp.linspace(0, jnp.pi / 2, num=B_input.size) + + def _transform(x): + y = transforms["B"].transform(x) + return transforms["grid"].compress(y) + + B_input = vmap(_transform)(B_lm.T) + # B_input has shape (num_knots, num_rho) + B_input = jnp.sort(B_input, axis=0) # sort to ensure monotonicity + eta_input = jnp.linspace(0, jnp.pi / 2, num=B_input.shape[0]) + eta = transforms["grid"].meshgrid_reshape(data["eta"], "rtz") + eta = eta.reshape((transforms["grid"].num_rho, -1)) + + def _interp(x, B): + return interp1d(x, eta_input, B, method="monotonic-0") # |B|_omnigeneous is an even function so B(-eta) = B(+eta) = B(|eta|) - data["|B|"] = interp1d( - jnp.abs(data["eta"]), eta_input, B_input, method="monotonic-0" + B = vmap(_interp)(jnp.abs(eta), B_input.T) # shape (nr, nt*nz) + B = B.reshape( + ( + transforms["grid"].num_rho, + transforms["grid"].num_poloidal, + transforms["grid"].num_zeta, + ) ) + B = jnp.moveaxis(B, 0, 1) + data["|B|"] = B.flatten(order="F") return data diff --git a/desc/objectives/_omnigenity.py b/desc/objectives/_omnigenity.py index 8ac2faa9e..2f54d51c1 100644 --- a/desc/objectives/_omnigenity.py +++ b/desc/objectives/_omnigenity.py @@ -2,8 +2,9 @@ import warnings -from desc.backend import jnp +from desc.backend import jnp, vmap from desc.compute import get_profiles, get_transforms +from desc.compute._omnigenity import _omnigenity_mapping from desc.compute.utils import _compute as compute_fun from desc.grid import LinearGrid from desc.utils import Timer, errorif, warnif @@ -515,13 +516,12 @@ class Omnigenity(_Objective): eq_grid : Grid, optional Collocation grid containing the nodes to evaluate at for equilibrium data. Defaults to a linearly space grid on the rho=1 surface. - Must be a single flux surface without stellarator symmetry. + Must be without stellarator symmetry. field_grid : Grid, optional Collocation grid containing the nodes to evaluate at for omnigenous field data. The grid nodes are given in the usual (ρ,θ,ζ) coordinates (with θ ∈ [0, 2π), ζ ∈ [0, 2π/NFP)), but θ is mapped to η and ζ is mapped to α. Defaults to a - linearly space grid on the rho=1 surface. Must be a single flux surface without - stellarator symmetry. + linearly space grid on the rho=1 surface. Must be without stellarator symmetry. M_booz : int, optional Poloidal resolution of Boozer transformation. Default = 2 * eq.M. N_booz : int, optional @@ -633,9 +633,9 @@ def build(self, use_jit=True, verbose=1): # default grids if self._eq_grid is None and self._field_grid is not None: - rho = self._field_grid.nodes[0, 0] + rho = self._field_grid.nodes[self._field_grid.unique_rho_idx, 0] elif self._eq_grid is not None and self._field_grid is None: - rho = self._eq_grid.nodes[0, 0] + rho = self._eq_grid.nodes[self._eq_grid.unique_rho_idx, 0] elif self._eq_grid is None and self._field_grid is None: rho = 1.0 if self._eq_grid is None: @@ -661,12 +661,13 @@ def build(self, use_jit=True, verbose=1): ) errorif(eq_grid.sym, msg="eq_grid must not be symmetric") errorif(field_grid.sym, msg="field_grid must not be symmetric") - errorif(eq_grid.num_rho != 1, msg="eq_grid must be a single surface") - errorif(field_grid.num_rho != 1, msg="field_grid must be a single surface") + field_rho = field_grid.nodes[field_grid.unique_rho_idx, 0] + eq_rho = eq_grid.nodes[eq_grid.unique_rho_idx, 0] errorif( - eq_grid.nodes[eq_grid.unique_rho_idx, 0] - != field_grid.nodes[field_grid.unique_rho_idx, 0], - msg="eq_grid and field_grid must be the same surface", + any(eq_rho != field_rho), + msg="eq_grid and field_grid must be the same surface(s), " + + f"eq_grid has surfaces {eq_rho}, " + + f"field_grid has surfaces {field_rho}", ) errorif( jnp.any(field.B_lm[: field.M_B] < 0), @@ -772,6 +773,9 @@ def compute(self, params_1=None, params_2=None, constants=None): eq_params = params_1 field_params = params_2 + eq_grid = constants["eq_transforms"]["grid"] + field_grid = constants["field_transforms"]["grid"] + # compute eq data if self._eq_fixed: eq_data = constants["eq_data"] @@ -789,27 +793,15 @@ def compute(self, params_1=None, params_2=None, constants=None): field_data = constants["field_data"] # update theta_B and zeta_B with new iota from the equilibrium M, N = constants["helicity"] - iota = jnp.mean(eq_data["iota"]) - # see comment in desc.compute._omnigenity for the explanation of these - # wheres - mat_OP = jnp.array( - [[N, iota / jnp.where(N == 0, 1, N)], [0, 1 / jnp.where(N == 0, 1, N)]] - ) - mat_OT = jnp.array([[0, -1], [M, -1 / jnp.where(iota == 0, 1.0, iota)]]) - den = jnp.where((N - M * iota) == 0, 1.0, (N - M * iota)) - mat_OH = jnp.array([[N, M * iota / den], [M, M / den]]) - matrix = jnp.where( - M == 0, - mat_OP, - jnp.where( - N == 0, - mat_OT, - mat_OH, - ), + iota = eq_data["iota"][eq_grid.unique_rho_idx] + theta_B, zeta_B = _omnigenity_mapping( + M, + N, + iota, + field_data["alpha"], + field_data["h"], + field_grid, ) - booz = matrix @ jnp.vstack((field_data["alpha"], field_data["h"])) - theta_B = booz[0, :] - zeta_B = booz[1, :] else: field_data = compute_fun( "desc.magnetic_fields._core.OmnigenousField", @@ -818,22 +810,38 @@ def compute(self, params_1=None, params_2=None, constants=None): transforms=constants["field_transforms"], profiles={}, helicity=constants["helicity"], - iota=jnp.mean(eq_data["iota"]), + iota=eq_data["iota"][eq_grid.unique_rho_idx], ) theta_B = field_data["theta_B"] zeta_B = field_data["zeta_B"] # additional computations that cannot be part of the regular compute API - nodes = jnp.vstack( - ( - jnp.zeros_like(theta_B), - theta_B, - zeta_B, + + def _compute_B_eta_alpha(theta_B, zeta_B, B_mn): + nodes = jnp.vstack( + ( + jnp.zeros_like(theta_B), + theta_B, + zeta_B, + ) + ).T + B_eta_alpha = jnp.matmul( + constants["eq_transforms"]["B"].basis.evaluate(nodes), B_mn ) - ).T - B_eta_alpha = jnp.matmul( - constants["eq_transforms"]["B"].basis.evaluate(nodes), eq_data["|B|_mn"] + return B_eta_alpha + + theta_B = field_grid.meshgrid_reshape(theta_B, "rtz").reshape( + (field_grid.num_rho, -1) + ) + zeta_B = field_grid.meshgrid_reshape(zeta_B, "rtz").reshape( + (field_grid.num_rho, -1) + ) + B_mn = eq_data["|B|_mn"].reshape((eq_grid.num_rho, -1)) + B_eta_alpha = vmap(_compute_B_eta_alpha)(theta_B, zeta_B, B_mn) + B_eta_alpha = B_eta_alpha.reshape( + (field_grid.num_rho, field_grid.num_theta, field_grid.num_zeta) ) + B_eta_alpha = jnp.moveaxis(B_eta_alpha, 0, 1).flatten(order="F") omnigenity_error = B_eta_alpha - field_data["|B|"] weights = (self.eta_weight + 1) / 2 + (self.eta_weight - 1) / 2 * jnp.cos( field_data["eta"] diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index c1461b45c..acfcb13cb 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -1348,6 +1348,58 @@ def test_signed_plasma_vessel_distance(self): ) obj.build() + @pytest.mark.unit + def test_omnigenity_multiple_surfaces(self): + """Test omnigenity transform vectorized over multiple surfaces.""" + surf = FourierRZToroidalSurface.from_qp_model( + major_radius=1, + aspect_ratio=20, + elongation=6, + mirror_ratio=0.2, + torsion=0.1, + NFP=1, + sym=True, + ) + eq = Equilibrium( + Psi=6e-3, + M=4, + N=4, + surface=surf, + iota=PowerSeriesProfile(1, 0, -1), # ensure diff surfs have diff iota + ) + field = OmnigenousField( + L_B=1, + M_B=3, + L_x=1, + M_x=1, + N_x=1, + NFP=eq.NFP, + helicity=(1, 1), + B_lm=np.array( + [ + [0.8, 1.0, 1.2], + [-0.4, 0.0, 0.6], # radially varying B + ] + ).flatten(), + ) + grid1 = LinearGrid(rho=0.5, M=eq.M_grid, N=eq.N_grid) + grid2 = LinearGrid(rho=1.0, M=eq.M_grid, N=eq.N_grid) + grid3 = LinearGrid(rho=np.array([0.5, 1.0]), M=eq.M_grid, N=eq.N_grid) + obj1 = Omnigenity(eq=eq, field=field, eq_grid=grid1) + obj2 = Omnigenity(eq=eq, field=field, eq_grid=grid2) + obj3 = Omnigenity(eq=eq, field=field, eq_grid=grid3) + obj1.build() + obj2.build() + obj3.build() + f1 = obj1.compute(*obj1.xs(eq, field)) + f2 = obj2.compute(*obj2.xs(eq, field)) + f3 = obj3.compute(*obj3.xs(eq, field)) + # the order will be different but the values should be the same so we sort + # before comparing + np.testing.assert_allclose( + np.sort(np.concatenate([f1, f2])), np.sort(f3), atol=1e-14 + ) + @pytest.mark.regression def test_derivative_modes():