Skip to content

Commit

Permalink
FEAT: support class attributes in unevaluated_expression (#375)
Browse files Browse the repository at this point in the history
* BREAK: switch arguments of `BlattWeisskopfSquared`
* MAINT: rewrite `BlattWeisskopfSquared` with decorator
* MAINT: test ClassVar of `BlattWeisskopfSquared`
  • Loading branch information
redeboer committed Dec 22, 2023
1 parent ee6e9ea commit feb5015
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 34 deletions.
4 changes: 2 additions & 2 deletions docs/_extend_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def extend_docstrings() -> None:
def extend_BlattWeisskopfSquared() -> None:
from ampform.dynamics import BlattWeisskopfSquared

L = sp.Symbol("L", integer=True)
z = sp.Symbol("z", real=True)
expr = BlattWeisskopfSquared(L, z)
L = sp.Symbol("L", integer=True)
expr = BlattWeisskopfSquared(z, angular_momentum=L)
_append_latex_doit_definition(expr, deep=True, full_width=True)


Expand Down
4 changes: 2 additions & 2 deletions docs/usage/dynamics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@
"\n",
"L = sp.Symbol(\"L\", integer=True)\n",
"z = sp.Symbol(\"z\", real=True)\n",
"ff2 = BlattWeisskopfSquared(L, z)\n",
"ff2 = BlattWeisskopfSquared(z, L)\n",
"Math(sp.multiline_latex(ff2, ff2.doit(), environment=\"eqnarray\"))"
]
},
Expand All @@ -183,7 +183,7 @@
"m, m_a, m_b, d = sp.symbols(\"m, m_a, m_b, d\")\n",
"s = m**2\n",
"q_squared = BreakupMomentumSquared(s, m_a, m_b)\n",
"ff2 = BlattWeisskopfSquared(L, z=q_squared * d**2)"
"ff2 = BlattWeisskopfSquared(q_squared * d**2, angular_momentum=L)"
]
},
{
Expand Down
36 changes: 16 additions & 20 deletions src/ampform/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# cspell:ignore asner mhash
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, ClassVar

import sympy as sp
from sympy.core.basic import _aresame
Expand All @@ -25,26 +25,26 @@
)
from ampform.sympy import (
UnevaluatedExpression,
create_expression,
implement_doit_method,
unevaluated_expression,
)

if TYPE_CHECKING:
from sympy.printing.latex import LatexPrinter


@implement_doit_method
class BlattWeisskopfSquared(UnevaluatedExpression):
@unevaluated_expression
class BlattWeisskopfSquared(sp.Expr):
# cspell:ignore pychyGekoppeltePartialwellenanalyseAnnihilationen
r"""Blatt-Weisskopf function :math:`B_L^2(z)`, up to :math:`L \leq 8`.
Args:
angular_momentum: Angular momentum :math:`L` of the decaying particle.
z: Argument of the Blatt-Weisskopf function :math:`B_L^2(z)`. A usual
choice is :math:`z = (d q)^2` with :math:`d` the impact parameter and
:math:`q` the breakup-momentum (see `.BreakupMomentumSquared`).
angular_momentum: Angular momentum :math:`L` of the decaying particle.
Note that equal powers of :math:`z` appear in the nominator and the denominator,
while some sources have nominator :math:`1`, instead of :math:`z^L`. Compare for
instance Equation (50.27) in :pdg-review:`2021; Resonances; p.9`.
Expand All @@ -57,20 +57,20 @@ class BlattWeisskopfSquared(UnevaluatedExpression):
See also :ref:`usage/dynamics:Form factor`.
"""
is_commutative = True
max_angular_momentum: int | None = None
z: Any
angular_momentum: Any
_latex_repr_ = R"B_{{{angular_momentum}}}^2\left({z}\right)"

max_angular_momentum: ClassVar[int | None] = None
"""Limit the maximum allowed angular momentum :math:`L`.
This improves performance when :math:`L` is a `~sympy.core.symbol.Symbol` and you
are note interested in higher angular momenta.
"""

def __new__(cls, angular_momentum, z, **hints) -> BlattWeisskopfSquared:
return create_expression(cls, angular_momentum, z, **hints)

def evaluate(self) -> sp.Expr:
angular_momentum: sp.Expr = self.args[0] # type: ignore[assignment]
z: sp.Expr = self.args[1] # type: ignore[assignment]
z: sp.Expr = self.args[0] # type: ignore[assignment]
angular_momentum: sp.Expr = self.args[1] # type: ignore[assignment]
cases: dict[int, sp.Expr] = {
0: sp.S.One,
1: 2 * z / (z + 1),
Expand Down Expand Up @@ -138,10 +138,6 @@ def evaluate(self) -> sp.Expr:
if self.max_angular_momentum is None or value <= self.max_angular_momentum
])

def _latex(self, printer: LatexPrinter, *args) -> str:
angular_momentum, z = tuple(map(printer._print, self.args))
return Rf"B_{{{angular_momentum}}}^2\left({z}\right)"


@implement_doit_method
class EnergyDependentWidth(UnevaluatedExpression):
Expand Down Expand Up @@ -208,12 +204,12 @@ def evaluate(self) -> sp.Expr:
q_squared = BreakupMomentumSquared(s, m_a, m_b)
q0_squared = BreakupMomentumSquared(mass0**2, m_a, m_b) # type: ignore[operator]
form_factor_sq = BlattWeisskopfSquared(
q_squared * meson_radius**2, # type: ignore[operator]
angular_momentum,
z=q_squared * meson_radius**2, # type: ignore[operator]
)
form_factor0_sq = BlattWeisskopfSquared(
q0_squared * meson_radius**2, # type: ignore[operator]
angular_momentum,
z=q0_squared * meson_radius**2, # type: ignore[operator]
)
rho = self.phsp_factor(s, m_a, m_b)
rho0 = self.phsp_factor(mass0**2, m_a, m_b) # type: ignore[operator]
Expand Down Expand Up @@ -303,5 +299,5 @@ def formulate_form_factor(s, m_a, m_b, angular_momentum, meson_radius) -> sp.Exp
`~sympy.functions.elementary.miscellaneous.sqrt` of a `.BlattWeisskopfSquared`.
"""
q_squared = BreakupMomentumSquared(s, m_a, m_b)
ff_squared = BlattWeisskopfSquared(angular_momentum, z=q_squared * meson_radius**2)
ff_squared = BlattWeisskopfSquared(q_squared * meson_radius**2, angular_momentum)
return sp.sqrt(ff_squared)
7 changes: 6 additions & 1 deletion src/ampform/sympy/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,19 @@ def _get_attribute_names(cls: type) -> tuple[str, ...]:
... a: int
... b: int
... _c: int
... n: ClassVar[int] = 2
...
... def print(self): ...
...
>>> _get_attribute_names(MyClass)
('a', 'b')
"""
return tuple(
k for k in cls.__annotations__ if not callable(k) if not k.startswith("_")
k
for k, v in cls.__annotations__.items()
if not callable(k)
if not k.startswith("_")
if not str(v).startswith("ClassVar")
)


Expand Down
14 changes: 13 additions & 1 deletion tests/dynamics/test_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TestBlattWeisskopfSquared:
def test_max_angular_momentum(self):
z = sp.Symbol("z")
angular_momentum = sp.Symbol("L", integer=True)
form_factor = BlattWeisskopfSquared(angular_momentum, z=z)
form_factor = BlattWeisskopfSquared(z, angular_momentum)
form_factor_9 = form_factor.subs(angular_momentum, 8).evaluate()
factor, z_power, _ = form_factor_9.args
assert factor == 4392846440677
Expand All @@ -35,6 +35,18 @@ def test_max_angular_momentum(self):
(1, sp.Eq(angular_momentum, 0)),
(2 * z / (z + 1), sp.Eq(angular_momentum, 1)),
)
BlattWeisskopfSquared.max_angular_momentum = None

def test_unevaluated_expression(self):
z = sp.Symbol("z")
ff1 = BlattWeisskopfSquared(z, angular_momentum=1)
ff2 = BlattWeisskopfSquared(z, angular_momentum=2)
assert ff1.max_angular_momentum is None
assert ff2.max_angular_momentum is None
BlattWeisskopfSquared.max_angular_momentum = 3
assert ff1.max_angular_momentum is 3 # noqa: F632
assert ff2.max_angular_momentum is 3 # noqa: F632
BlattWeisskopfSquared.max_angular_momentum = None


class TestEnergyDependentWidth:
Expand Down
2 changes: 1 addition & 1 deletion tests/dynamics/test_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_pickle():
assert expr == imported_expr

# Pickle classes derived from UnevaluatedExpression
expr = BlattWeisskopfSquared(angular_momentum, z=z)
expr = BlattWeisskopfSquared(z, angular_momentum)
pickled_obj = pickle.dumps(expr)
imported_expr = pickle.loads(pickled_obj) # noqa: S301
assert expr == imported_expr
Expand Down
12 changes: 6 additions & 6 deletions tests/sympy/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,17 @@ def test_get_readable_hash_large(amplitude_model: tuple[str, HelicityModel]):
# https://github.com/ComPWA/ampform/actions/runs/3277058875/jobs/5393849802
# https://github.com/ComPWA/ampform/actions/runs/3277143883/jobs/5394043014
expected_hash = {
"canonical-helicity": "pythonhashseed-0-6040455869260657745",
"helicity": "pythonhashseed-0-1928646339459384503",
"canonical-helicity": "pythonhashseed-0-3873186712292274641",
"helicity": "pythonhashseed-0-8800154542426799839",
}[formalism]
elif sys.version_info >= (3, 11):
expected_hash = {
"canonical-helicity": "pythonhashseed-0+409069872540431022",
"helicity": "pythonhashseed-0-8907705932662936900",
"canonical-helicity": "pythonhashseed-0+4035132515642199515",
"helicity": "pythonhashseed-0-2843057473565885663",
}[formalism]
else:
expected_hash = {
"canonical-helicity": "pythonhashseed-0-7143983882032045549",
"helicity": "pythonhashseed-0+3357246175053927117",
"canonical-helicity": "pythonhashseed-0+3420919389670627445",
"helicity": "pythonhashseed-0-6681863313351758450",
}[formalism]
assert get_readable_hash(model.expression) == expected_hash
23 changes: 22 additions & 1 deletion tests/sympy/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import inspect
from typing import Any
from typing import Any, ClassVar

import pytest
import sympy as sp
Expand Down Expand Up @@ -113,6 +113,27 @@ def evaluate(self) -> sp.Expr:
assert isinstance(q_value.m2, sp.Float)


def test_unevaluated_expression_classvar():
@unevaluated_expression
class MyExpr(sp.Expr):
x: float
m: ClassVar[int] = 2

def evaluate(self) -> sp.Expr:
return self.x**self.m # type: ignore[return-value]

x_expr = MyExpr(4)
assert x_expr.x is sp.Integer(4)
assert x_expr.m is 2 # noqa: F632

y_expr = MyExpr(5)
assert x_expr.doit() == 4**2
assert y_expr.doit() == 5**2
MyExpr.m = 3
assert x_expr.doit() == 4**3
assert y_expr.doit() == 5**3


def test_unevaluated_expression_callable():
@unevaluated_expression(implement_doit=False)
class Squared(sp.Expr):
Expand Down

0 comments on commit feb5015

Please sign in to comment.