Skip to content

Commit

Permalink
ENH: make implementation method public as evaluate()
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Nov 24, 2023
1 parent af503b1 commit a1a6a42
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 35 deletions.
43 changes: 19 additions & 24 deletions src/ampform/kinematics/phasespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,25 @@

from __future__ import annotations

from typing import Any

import sympy as sp

from ampform.sympy import (
UnevaluatedExpression,
create_expression,
implement_doit_method,
make_commutative,
)
from ampform.sympy import unevaluated_expression


@make_commutative
@implement_doit_method
class Kibble(UnevaluatedExpression):
@unevaluated_expression
class Kibble(sp.Expr):
"""Kibble function for determining the phase space region."""

def __new__(cls, sigma1, sigma2, sigma3, m0, m1, m2, m3, **hints) -> Kibble:
return create_expression(cls, sigma1, sigma2, sigma3, m0, m1, m2, m3, **hints)
sigma1: Any
sigma2: Any
sigma3: Any
m0: Any
m1: Any
m2: Any
m3: Any
_latex_repr_ = R"\phi\left({sigma1}, {sigma2}\right)"

def evaluate(self) -> Kallen:
sigma1, sigma2, sigma3, m0, m1, m2, m3 = self.args
Expand All @@ -31,27 +33,20 @@ def evaluate(self) -> Kallen:
Kallen(sigma1, m1**2, m0**2), # type: ignore[operator]
)

def _latex(self, printer, *args):
sigma1, sigma2, *_ = map(printer._print, self.args)
return Rf"\phi\left({sigma1}, {sigma2}\right)"


@make_commutative
@implement_doit_method
class Kallen(UnevaluatedExpression):
@unevaluated_expression
class Kallen(sp.Expr):
"""Källén function, used for computing break-up momenta."""

def __new__(cls, x, y, z, **hints) -> Kallen:
return create_expression(cls, x, y, z, **hints)
x: Any
y: Any
z: Any
_latex_repr_ = R"\lambda\left({x}, {y}, {z}\right)"

def evaluate(self) -> sp.Expr:
x, y, z = self.args
return x**2 + y**2 + z**2 - 2 * x * y - 2 * y * z - 2 * z * x # type: ignore[operator]

def _latex(self, printer, *args):
x, y, z = map(printer._print, self.args)
return Rf"\lambda\left({x}, {y}, {z}\right)"


def is_within_phasespace(
sigma1, sigma2, m0, m1, m2, m3, outside_value=sp.nan
Expand Down
12 changes: 6 additions & 6 deletions src/ampform/sympy/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def unevaluated_expression( # type: ignore[misc]
... y: sp.Symbol
... _latex_repr_ = R"z\left({x}, {y}\right)"
...
... def _implementation_(self) -> sp.Expr:
... def evaluate(self) -> sp.Expr:
... x, y = self.args
... return x**2 + y**2
...
Expand Down Expand Up @@ -139,7 +139,7 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]:
for name, value in zip(attr_names, args):
setattr(expr, name, value)
if evaluate:
return expr._implementation_()
return expr.evaluate()
return expr

cls.__new__ = new_method # type: ignore[method-assign]
Expand Down Expand Up @@ -201,7 +201,7 @@ def _implement_doit(cls: type[ExprClass]) -> type[ExprClass]:

@functools.wraps(cls.doit)
def doit_method(self, deep: bool = True) -> sp.Expr:
expr = self._implementation_()
expr = self.evaluate()
if deep:
return expr.doit()
return expr
Expand All @@ -211,12 +211,12 @@ def doit_method(self, deep: bool = True) -> sp.Expr:


def _check_has_implementation(cls: type) -> None:
implementation_method = getattr(cls, "_implementation_", None)
implementation_method = getattr(cls, "evaluate", None)
if implementation_method is None:
msg = "Decorated class must have an _implementation_ method"
msg = "Decorated class must have an evaluate() method"
raise ValueError(msg)
if not callable(implementation_method):
msg = "_implementation_ must be a callable method"
msg = "evaluate must be a callable method"
raise TypeError(msg)


Expand Down
10 changes: 5 additions & 5 deletions tests/sympy/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@


def test_check_implementation():
with pytest.raises(ValueError, match="must have an _implementation_ method"):
with pytest.raises(ValueError, match="must have an evaluate() method"):

@_check_has_implementation
class MyExpr1: # pyright: ignore[reportUnusedClass]
pass

with pytest.raises(TypeError, match="_implementation_ must be a callable method"):
with pytest.raises(TypeError, match="evaluate()s must be a callable method"):

@_check_has_implementation
class MyExpr2: # pyright: ignore[reportUnusedClass]
_implementation_ = "test"
evaluate = "test"


def test_implement_latex_repr():
Expand Down Expand Up @@ -84,7 +84,7 @@ class BreakupMomentum(sp.Expr):
m2: sp.Basic
_latex_repr_ = R"q\left({s}\right)"

def _implementation_(self) -> sp.Expr:
def evaluate(self) -> sp.Expr:
s, m1, m2 = self.args
return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2)) # type: ignore[operator]

Expand All @@ -101,7 +101,7 @@ def test_unevaluated_expression_callable():
class Squared(sp.Expr):
x: Any

def _implementation_(self) -> sp.Expr:
def evaluate(self) -> sp.Expr:
return self.x**2

sqrt = Squared(2)
Expand Down

0 comments on commit a1a6a42

Please sign in to comment.