From feb50154ba70b6ba963246c38556991a2035e906 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 18 Dec 2023 15:14:35 +0100 Subject: [PATCH] FEAT: support class attributes in `unevaluated_expression` (#375) * BREAK: switch arguments of `BlattWeisskopfSquared` * MAINT: rewrite `BlattWeisskopfSquared` with decorator * MAINT: test ClassVar of `BlattWeisskopfSquared` --- docs/_extend_docstrings.py | 4 ++-- docs/usage/dynamics.ipynb | 4 ++-- src/ampform/dynamics/__init__.py | 36 ++++++++++++++------------------ src/ampform/sympy/_decorator.py | 7 ++++++- tests/dynamics/test_dynamics.py | 14 ++++++++++++- tests/dynamics/test_sympy.py | 2 +- tests/sympy/test_caching.py | 12 +++++------ tests/sympy/test_decorator.py | 23 +++++++++++++++++++- 8 files changed, 68 insertions(+), 34 deletions(-) diff --git a/docs/_extend_docstrings.py b/docs/_extend_docstrings.py index f3cb5ecb5..e7e261816 100644 --- a/docs/_extend_docstrings.py +++ b/docs/_extend_docstrings.py @@ -65,9 +65,9 @@ def extend_docstrings() -> None: def extend_BlattWeisskopfSquared() -> None: from ampform.dynamics import BlattWeisskopfSquared - L = sp.Symbol("L", integer=True) z = sp.Symbol("z", real=True) - expr = BlattWeisskopfSquared(L, z) + L = sp.Symbol("L", integer=True) + expr = BlattWeisskopfSquared(z, angular_momentum=L) _append_latex_doit_definition(expr, deep=True, full_width=True) diff --git a/docs/usage/dynamics.ipynb b/docs/usage/dynamics.ipynb index 34fd56954..806efacb8 100644 --- a/docs/usage/dynamics.ipynb +++ b/docs/usage/dynamics.ipynb @@ -159,7 +159,7 @@ "\n", "L = sp.Symbol(\"L\", integer=True)\n", "z = sp.Symbol(\"z\", real=True)\n", - "ff2 = BlattWeisskopfSquared(L, z)\n", + "ff2 = BlattWeisskopfSquared(z, L)\n", "Math(sp.multiline_latex(ff2, ff2.doit(), environment=\"eqnarray\"))" ] }, @@ -183,7 +183,7 @@ "m, m_a, m_b, d = sp.symbols(\"m, m_a, m_b, d\")\n", "s = m**2\n", "q_squared = BreakupMomentumSquared(s, m_a, m_b)\n", - "ff2 = BlattWeisskopfSquared(L, z=q_squared * d**2)" + "ff2 = BlattWeisskopfSquared(q_squared * d**2, angular_momentum=L)" ] }, { diff --git a/src/ampform/dynamics/__init__.py b/src/ampform/dynamics/__init__.py index 41883b159..c91e88a2b 100644 --- a/src/ampform/dynamics/__init__.py +++ b/src/ampform/dynamics/__init__.py @@ -6,7 +6,7 @@ # cspell:ignore asner mhash from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, ClassVar import sympy as sp from sympy.core.basic import _aresame @@ -25,26 +25,26 @@ ) from ampform.sympy import ( UnevaluatedExpression, - create_expression, implement_doit_method, + unevaluated_expression, ) if TYPE_CHECKING: from sympy.printing.latex import LatexPrinter -@implement_doit_method -class BlattWeisskopfSquared(UnevaluatedExpression): +@unevaluated_expression +class BlattWeisskopfSquared(sp.Expr): # cspell:ignore pychyGekoppeltePartialwellenanalyseAnnihilationen r"""Blatt-Weisskopf function :math:`B_L^2(z)`, up to :math:`L \leq 8`. Args: - angular_momentum: Angular momentum :math:`L` of the decaying particle. - z: Argument of the Blatt-Weisskopf function :math:`B_L^2(z)`. A usual choice is :math:`z = (d q)^2` with :math:`d` the impact parameter and :math:`q` the breakup-momentum (see `.BreakupMomentumSquared`). + angular_momentum: Angular momentum :math:`L` of the decaying particle. + Note that equal powers of :math:`z` appear in the nominator and the denominator, while some sources have nominator :math:`1`, instead of :math:`z^L`. Compare for instance Equation (50.27) in :pdg-review:`2021; Resonances; p.9`. @@ -57,20 +57,20 @@ class BlattWeisskopfSquared(UnevaluatedExpression): See also :ref:`usage/dynamics:Form factor`. """ - is_commutative = True - max_angular_momentum: int | None = None + z: Any + angular_momentum: Any + _latex_repr_ = R"B_{{{angular_momentum}}}^2\left({z}\right)" + + max_angular_momentum: ClassVar[int | None] = None """Limit the maximum allowed angular momentum :math:`L`. This improves performance when :math:`L` is a `~sympy.core.symbol.Symbol` and you are note interested in higher angular momenta. """ - def __new__(cls, angular_momentum, z, **hints) -> BlattWeisskopfSquared: - return create_expression(cls, angular_momentum, z, **hints) - def evaluate(self) -> sp.Expr: - angular_momentum: sp.Expr = self.args[0] # type: ignore[assignment] - z: sp.Expr = self.args[1] # type: ignore[assignment] + z: sp.Expr = self.args[0] # type: ignore[assignment] + angular_momentum: sp.Expr = self.args[1] # type: ignore[assignment] cases: dict[int, sp.Expr] = { 0: sp.S.One, 1: 2 * z / (z + 1), @@ -138,10 +138,6 @@ def evaluate(self) -> sp.Expr: if self.max_angular_momentum is None or value <= self.max_angular_momentum ]) - def _latex(self, printer: LatexPrinter, *args) -> str: - angular_momentum, z = tuple(map(printer._print, self.args)) - return Rf"B_{{{angular_momentum}}}^2\left({z}\right)" - @implement_doit_method class EnergyDependentWidth(UnevaluatedExpression): @@ -208,12 +204,12 @@ def evaluate(self) -> sp.Expr: q_squared = BreakupMomentumSquared(s, m_a, m_b) q0_squared = BreakupMomentumSquared(mass0**2, m_a, m_b) # type: ignore[operator] form_factor_sq = BlattWeisskopfSquared( + q_squared * meson_radius**2, # type: ignore[operator] angular_momentum, - z=q_squared * meson_radius**2, # type: ignore[operator] ) form_factor0_sq = BlattWeisskopfSquared( + q0_squared * meson_radius**2, # type: ignore[operator] angular_momentum, - z=q0_squared * meson_radius**2, # type: ignore[operator] ) rho = self.phsp_factor(s, m_a, m_b) rho0 = self.phsp_factor(mass0**2, m_a, m_b) # type: ignore[operator] @@ -303,5 +299,5 @@ def formulate_form_factor(s, m_a, m_b, angular_momentum, meson_radius) -> sp.Exp `~sympy.functions.elementary.miscellaneous.sqrt` of a `.BlattWeisskopfSquared`. """ q_squared = BreakupMomentumSquared(s, m_a, m_b) - ff_squared = BlattWeisskopfSquared(angular_momentum, z=q_squared * meson_radius**2) + ff_squared = BlattWeisskopfSquared(q_squared * meson_radius**2, angular_momentum) return sp.sqrt(ff_squared) diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index 1a03edb9d..e8501b84e 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -256,6 +256,7 @@ def _get_attribute_names(cls: type) -> tuple[str, ...]: ... a: int ... b: int ... _c: int + ... n: ClassVar[int] = 2 ... ... def print(self): ... ... @@ -263,7 +264,11 @@ def _get_attribute_names(cls: type) -> tuple[str, ...]: ('a', 'b') """ return tuple( - k for k in cls.__annotations__ if not callable(k) if not k.startswith("_") + k + for k, v in cls.__annotations__.items() + if not callable(k) + if not k.startswith("_") + if not str(v).startswith("ClassVar") ) diff --git a/tests/dynamics/test_dynamics.py b/tests/dynamics/test_dynamics.py index b5b134908..c1282700c 100644 --- a/tests/dynamics/test_dynamics.py +++ b/tests/dynamics/test_dynamics.py @@ -24,7 +24,7 @@ class TestBlattWeisskopfSquared: def test_max_angular_momentum(self): z = sp.Symbol("z") angular_momentum = sp.Symbol("L", integer=True) - form_factor = BlattWeisskopfSquared(angular_momentum, z=z) + form_factor = BlattWeisskopfSquared(z, angular_momentum) form_factor_9 = form_factor.subs(angular_momentum, 8).evaluate() factor, z_power, _ = form_factor_9.args assert factor == 4392846440677 @@ -35,6 +35,18 @@ def test_max_angular_momentum(self): (1, sp.Eq(angular_momentum, 0)), (2 * z / (z + 1), sp.Eq(angular_momentum, 1)), ) + BlattWeisskopfSquared.max_angular_momentum = None + + def test_unevaluated_expression(self): + z = sp.Symbol("z") + ff1 = BlattWeisskopfSquared(z, angular_momentum=1) + ff2 = BlattWeisskopfSquared(z, angular_momentum=2) + assert ff1.max_angular_momentum is None + assert ff2.max_angular_momentum is None + BlattWeisskopfSquared.max_angular_momentum = 3 + assert ff1.max_angular_momentum is 3 # noqa: F632 + assert ff2.max_angular_momentum is 3 # noqa: F632 + BlattWeisskopfSquared.max_angular_momentum = None class TestEnergyDependentWidth: diff --git a/tests/dynamics/test_sympy.py b/tests/dynamics/test_sympy.py index 24fbb59b5..d0123b999 100644 --- a/tests/dynamics/test_sympy.py +++ b/tests/dynamics/test_sympy.py @@ -28,7 +28,7 @@ def test_pickle(): assert expr == imported_expr # Pickle classes derived from UnevaluatedExpression - expr = BlattWeisskopfSquared(angular_momentum, z=z) + expr = BlattWeisskopfSquared(z, angular_momentum) pickled_obj = pickle.dumps(expr) imported_expr = pickle.loads(pickled_obj) # noqa: S301 assert expr == imported_expr diff --git a/tests/sympy/test_caching.py b/tests/sympy/test_caching.py index 55a033e38..c9f2eebf5 100644 --- a/tests/sympy/test_caching.py +++ b/tests/sympy/test_caching.py @@ -78,17 +78,17 @@ def test_get_readable_hash_large(amplitude_model: tuple[str, HelicityModel]): # https://github.com/ComPWA/ampform/actions/runs/3277058875/jobs/5393849802 # https://github.com/ComPWA/ampform/actions/runs/3277143883/jobs/5394043014 expected_hash = { - "canonical-helicity": "pythonhashseed-0-6040455869260657745", - "helicity": "pythonhashseed-0-1928646339459384503", + "canonical-helicity": "pythonhashseed-0-3873186712292274641", + "helicity": "pythonhashseed-0-8800154542426799839", }[formalism] elif sys.version_info >= (3, 11): expected_hash = { - "canonical-helicity": "pythonhashseed-0+409069872540431022", - "helicity": "pythonhashseed-0-8907705932662936900", + "canonical-helicity": "pythonhashseed-0+4035132515642199515", + "helicity": "pythonhashseed-0-2843057473565885663", }[formalism] else: expected_hash = { - "canonical-helicity": "pythonhashseed-0-7143983882032045549", - "helicity": "pythonhashseed-0+3357246175053927117", + "canonical-helicity": "pythonhashseed-0+3420919389670627445", + "helicity": "pythonhashseed-0-6681863313351758450", }[formalism] assert get_readable_hash(model.expression) == expected_hash diff --git a/tests/sympy/test_decorator.py b/tests/sympy/test_decorator.py index 40bf1d50a..e446ab5a0 100644 --- a/tests/sympy/test_decorator.py +++ b/tests/sympy/test_decorator.py @@ -1,7 +1,7 @@ from __future__ import annotations import inspect -from typing import Any +from typing import Any, ClassVar import pytest import sympy as sp @@ -113,6 +113,27 @@ def evaluate(self) -> sp.Expr: assert isinstance(q_value.m2, sp.Float) +def test_unevaluated_expression_classvar(): + @unevaluated_expression + class MyExpr(sp.Expr): + x: float + m: ClassVar[int] = 2 + + def evaluate(self) -> sp.Expr: + return self.x**self.m # type: ignore[return-value] + + x_expr = MyExpr(4) + assert x_expr.x is sp.Integer(4) + assert x_expr.m is 2 # noqa: F632 + + y_expr = MyExpr(5) + assert x_expr.doit() == 4**2 + assert y_expr.doit() == 5**2 + MyExpr.m = 3 + assert x_expr.doit() == 4**3 + assert y_expr.doit() == 5**3 + + def test_unevaluated_expression_callable(): @unevaluated_expression(implement_doit=False) class Squared(sp.Expr):