From a1a6a42f23308263b3a4dfe0da8f278671cf8dd3 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 24 Nov 2023 10:34:17 +0100 Subject: [PATCH] ENH: make implementation method public as `evaluate()` --- src/ampform/kinematics/phasespace.py | 43 ++++++++++++---------------- src/ampform/sympy/_decorator.py | 12 ++++---- tests/sympy/test_decorator.py | 10 +++---- 3 files changed, 30 insertions(+), 35 deletions(-) diff --git a/src/ampform/kinematics/phasespace.py b/src/ampform/kinematics/phasespace.py index f82567ae5..8796a8812 100644 --- a/src/ampform/kinematics/phasespace.py +++ b/src/ampform/kinematics/phasespace.py @@ -5,23 +5,25 @@ from __future__ import annotations +from typing import Any + import sympy as sp -from ampform.sympy import ( - UnevaluatedExpression, - create_expression, - implement_doit_method, - make_commutative, -) +from ampform.sympy import unevaluated_expression -@make_commutative -@implement_doit_method -class Kibble(UnevaluatedExpression): +@unevaluated_expression +class Kibble(sp.Expr): """Kibble function for determining the phase space region.""" - def __new__(cls, sigma1, sigma2, sigma3, m0, m1, m2, m3, **hints) -> Kibble: - return create_expression(cls, sigma1, sigma2, sigma3, m0, m1, m2, m3, **hints) + sigma1: Any + sigma2: Any + sigma3: Any + m0: Any + m1: Any + m2: Any + m3: Any + _latex_repr_ = R"\phi\left({sigma1}, {sigma2}\right)" def evaluate(self) -> Kallen: sigma1, sigma2, sigma3, m0, m1, m2, m3 = self.args @@ -31,27 +33,20 @@ def evaluate(self) -> Kallen: Kallen(sigma1, m1**2, m0**2), # type: ignore[operator] ) - def _latex(self, printer, *args): - sigma1, sigma2, *_ = map(printer._print, self.args) - return Rf"\phi\left({sigma1}, {sigma2}\right)" - -@make_commutative -@implement_doit_method -class Kallen(UnevaluatedExpression): +@unevaluated_expression +class Kallen(sp.Expr): """Källén function, used for computing break-up momenta.""" - def __new__(cls, x, y, z, **hints) -> Kallen: - return create_expression(cls, x, y, z, **hints) + x: Any + y: Any + z: Any + _latex_repr_ = R"\lambda\left({x}, {y}, {z}\right)" def evaluate(self) -> sp.Expr: x, y, z = self.args return x**2 + y**2 + z**2 - 2 * x * y - 2 * y * z - 2 * z * x # type: ignore[operator] - def _latex(self, printer, *args): - x, y, z = map(printer._print, self.args) - return Rf"\lambda\left({x}, {y}, {z}\right)" - def is_within_phasespace( sigma1, sigma2, m0, m1, m2, m3, outside_value=sp.nan diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index 1d8031805..04cd46795 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -80,7 +80,7 @@ def unevaluated_expression( # type: ignore[misc] ... y: sp.Symbol ... _latex_repr_ = R"z\left({x}, {y}\right)" ... - ... def _implementation_(self) -> sp.Expr: + ... def evaluate(self) -> sp.Expr: ... x, y = self.args ... return x**2 + y**2 ... @@ -139,7 +139,7 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: for name, value in zip(attr_names, args): setattr(expr, name, value) if evaluate: - return expr._implementation_() + return expr.evaluate() return expr cls.__new__ = new_method # type: ignore[method-assign] @@ -201,7 +201,7 @@ def _implement_doit(cls: type[ExprClass]) -> type[ExprClass]: @functools.wraps(cls.doit) def doit_method(self, deep: bool = True) -> sp.Expr: - expr = self._implementation_() + expr = self.evaluate() if deep: return expr.doit() return expr @@ -211,12 +211,12 @@ def doit_method(self, deep: bool = True) -> sp.Expr: def _check_has_implementation(cls: type) -> None: - implementation_method = getattr(cls, "_implementation_", None) + implementation_method = getattr(cls, "evaluate", None) if implementation_method is None: - msg = "Decorated class must have an _implementation_ method" + msg = "Decorated class must have an evaluate() method" raise ValueError(msg) if not callable(implementation_method): - msg = "_implementation_ must be a callable method" + msg = "evaluate must be a callable method" raise TypeError(msg) diff --git a/tests/sympy/test_decorator.py b/tests/sympy/test_decorator.py index bc38773e9..1f75869b3 100644 --- a/tests/sympy/test_decorator.py +++ b/tests/sympy/test_decorator.py @@ -17,17 +17,17 @@ def test_check_implementation(): - with pytest.raises(ValueError, match="must have an _implementation_ method"): + with pytest.raises(ValueError, match="must have an evaluate() method"): @_check_has_implementation class MyExpr1: # pyright: ignore[reportUnusedClass] pass - with pytest.raises(TypeError, match="_implementation_ must be a callable method"): + with pytest.raises(TypeError, match="evaluate()s must be a callable method"): @_check_has_implementation class MyExpr2: # pyright: ignore[reportUnusedClass] - _implementation_ = "test" + evaluate = "test" def test_implement_latex_repr(): @@ -84,7 +84,7 @@ class BreakupMomentum(sp.Expr): m2: sp.Basic _latex_repr_ = R"q\left({s}\right)" - def _implementation_(self) -> sp.Expr: + def evaluate(self) -> sp.Expr: s, m1, m2 = self.args return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2)) # type: ignore[operator] @@ -101,7 +101,7 @@ def test_unevaluated_expression_callable(): class Squared(sp.Expr): x: Any - def _implementation_(self) -> sp.Expr: + def evaluate(self) -> sp.Expr: return self.x**2 sqrt = Squared(2)