diff --git a/tests/sympy/decorator/__init__.py b/tests/sympy/decorator/__init__.py new file mode 100644 index 000000000..948df262f --- /dev/null +++ b/tests/sympy/decorator/__init__.py @@ -0,0 +1 @@ +"""Required to set mypy options for the tests folder.""" diff --git a/tests/sympy/decorator/test_helpers.py b/tests/sympy/decorator/test_helpers.py new file mode 100644 index 000000000..107ecdf47 --- /dev/null +++ b/tests/sympy/decorator/test_helpers.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import inspect + +import pytest +import sympy as sp + +from ampform.sympy._decorator import ( + _check_has_implementation, + _implement_latex_repr, + _implement_new_method, + _insert_args_in_signature, + _set_assumptions, +) + + +def test_check_has_implementation(): + with pytest.raises(ValueError, match=r"must have an evaluate\(\) method"): + + @_check_has_implementation + class MyExpr1: # pyright: ignore[reportUnusedClass] + pass + + with pytest.raises(TypeError, match=r"evaluate\(\) must be a callable method"): + + @_check_has_implementation + class MyExpr2: # pyright: ignore[reportUnusedClass] + evaluate = "test" + + +def test_implement_latex_repr(): + @_implement_latex_repr + @_implement_new_method + class MyExpr(sp.Expr): + a: sp.Symbol + b: sp.Symbol + _latex_repr_ = R"f\left({a}, {b}\right)" + + alpha, phi = sp.symbols("alpha phi") + expr = MyExpr(alpha, sp.cos(phi)) + assert sp.latex(expr) == R"f\left(\alpha, \cos{\left(\phi \right)}\right)" + + +def test_implement_new_method(): + @_implement_new_method + class MyExpr(sp.Expr): + a: int + b: int + c: int + + with pytest.raises( + ValueError, match=r"^Expecting 3 positional arguments \(a, b, c\), but got 4$" + ): + MyExpr(1, 2, 3, 4) # type: ignore[call-arg] + with pytest.raises(ValueError, match=r"^Missing constructor arguments: c$"): + MyExpr(1, 2) # type: ignore[call-arg] + expr = MyExpr(1, 2, 3) + assert expr.args == (1, 2, 3) + expr = MyExpr(1, b=2, c=3) + assert expr.args == (1, 2, 3) + + +def test_insert_args_in_signature(): + parameters = ["a", "b"] + + @_insert_args_in_signature(parameters) + def my_func(*args, **kwargs) -> int: + return 1 + + signature = inspect.signature(my_func) + assert list(signature.parameters) == [*parameters, "args", "kwargs"] + assert signature.return_annotation == "int" + + +def test_set_assumptions(): + @_set_assumptions(commutative=True, negative=False, real=True) + class MySqrt: ... + + expr = MySqrt() + assert expr.is_commutative # type: ignore[attr-defined] + assert not expr.is_negative # type: ignore[attr-defined] + assert expr.is_real # type: ignore[attr-defined] diff --git a/tests/sympy/test_decorator.py b/tests/sympy/decorator/test_unevaluated_expression.py similarity index 57% rename from tests/sympy/test_decorator.py rename to tests/sympy/decorator/test_unevaluated_expression.py index b9e4209a2..f04de40bf 100644 --- a/tests/sympy/test_decorator.py +++ b/tests/sympy/decorator/test_unevaluated_expression.py @@ -3,117 +3,12 @@ import inspect from typing import Any, ClassVar -import pytest import sympy as sp -from ampform.sympy._decorator import ( - _check_has_implementation, - _implement_latex_repr, - _implement_new_method, - _insert_args_in_signature, - _set_assumptions, - unevaluated_expression, -) +from ampform.sympy._decorator import unevaluated_expression -def test_check_implementation(): - with pytest.raises(ValueError, match=r"must have an evaluate\(\) method"): - - @_check_has_implementation - class MyExpr1: # pyright: ignore[reportUnusedClass] - pass - - with pytest.raises(TypeError, match=r"evaluate\(\) must be a callable method"): - - @_check_has_implementation - class MyExpr2: # pyright: ignore[reportUnusedClass] - evaluate = "test" - - -def test_implement_latex_repr(): - @_implement_latex_repr - @_implement_new_method - class MyExpr(sp.Expr): - a: sp.Symbol - b: sp.Symbol - _latex_repr_ = R"f\left({a}, {b}\right)" - - alpha, phi = sp.symbols("alpha phi") - expr = MyExpr(alpha, sp.cos(phi)) - assert sp.latex(expr) == R"f\left(\alpha, \cos{\left(\phi \right)}\right)" - - -def test_implement_new_method(): - @_implement_new_method - class MyExpr(sp.Expr): - a: int - b: int - c: int - - with pytest.raises( - ValueError, match=r"^Expecting 3 positional arguments \(a, b, c\), but got 4$" - ): - MyExpr(1, 2, 3, 4) # type: ignore[call-arg] - with pytest.raises(ValueError, match=r"^Missing constructor arguments: c$"): - MyExpr(1, 2) # type: ignore[call-arg] - expr = MyExpr(1, 2, 3) - assert expr.args == (1, 2, 3) - expr = MyExpr(1, b=2, c=3) - assert expr.args == (1, 2, 3) - - -def test_insert_args_in_signature(): - parameters = ["a", "b"] - - @_insert_args_in_signature(parameters) - def my_func(*args, **kwargs) -> int: - return 1 - - signature = inspect.signature(my_func) - assert list(signature.parameters) == [*parameters, "args", "kwargs"] - assert signature.return_annotation == "int" - - -def test_set_assumptions(): - @_set_assumptions(commutative=True, negative=False, real=True) - class MySqrt: ... - - expr = MySqrt() - assert expr.is_commutative # type: ignore[attr-defined] - assert not expr.is_negative # type: ignore[attr-defined] - assert expr.is_real # type: ignore[attr-defined] - - -def test_unevaluated_expression(): - @unevaluated_expression - class BreakupMomentum(sp.Expr): - s: Any - m1: Any - m2: Any - _latex_repr_ = R"q\left({s}\right)" - - def evaluate(self) -> sp.Expr: - s, m1, m2 = self.args - return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2)) # type: ignore[operator] - - m0, ma, mb = sp.symbols("m0 m_a m_b") - expr = BreakupMomentum(m0**2, ma, mb) - assert expr.s is m0**2 - assert expr.m1 is ma - assert expr.m2 is mb - assert expr.is_commutative is True - args_str = list(inspect.signature(expr.__new__).parameters) - assert args_str == ["s", "m1", "m2", "args", "evaluate", "kwargs"] - latex = sp.latex(expr) - assert latex == R"q\left(m_{0}^{2}\right)" - - q_value = BreakupMomentum(1, m1=0.2, m2=0.4) - assert isinstance(q_value.s, sp.Integer) - assert isinstance(q_value.m1, sp.Float) - assert isinstance(q_value.m2, sp.Float) - - -def test_unevaluated_expression_classvar(): +def test_classvar_behavior(): @unevaluated_expression class MyExpr(sp.Expr): x: float @@ -134,7 +29,24 @@ def evaluate(self) -> sp.Expr: assert y_expr.doit() == 5**3 -def test_unevaluated_expression_default_argument(): +def test_default_argument(): + @unevaluated_expression + class MyExpr(sp.Expr): + x: Any + m: int = 2 + + def evaluate(self) -> sp.Expr: + return self.x**self.m + + expr1 = MyExpr(x=5) + assert str(expr1) == "MyExpr(5, 2)" + assert expr1.doit() == 5**2 + + expr2 = MyExpr(4, 3) + assert expr2.doit() == 4**3 + + +def test_default_argument_with_classvar(): @unevaluated_expression class FunkyPower(sp.Expr): x: Any @@ -171,7 +83,7 @@ def evaluate(self) -> sp.Expr: assert expr.default_return is half -def test_unevaluated_expression_callable(): +def test_no_implement_doit(): @unevaluated_expression(implement_doit=False) class Squared(sp.Expr): x: Any @@ -192,18 +104,30 @@ class MySqrt(sp.Expr): assert expr.is_complex # type: ignore[attr-defined] -def test_unevaluated_expression_default_args(): +def test_symbols_and_no_symbols(): @unevaluated_expression - class MyExpr(sp.Expr): - x: Any - m: int = 2 + class BreakupMomentum(sp.Expr): + s: Any + m1: Any + m2: Any + _latex_repr_ = R"q\left({s}\right)" def evaluate(self) -> sp.Expr: - return self.x**self.m + s, m1, m2 = self.args + return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2)) # type: ignore[operator] - expr1 = MyExpr(x=5) - assert str(expr1) == "MyExpr(5, 2)" - assert expr1.doit() == 5**2 + m0, ma, mb = sp.symbols("m0 m_a m_b") + expr = BreakupMomentum(m0**2, ma, mb) + assert expr.s is m0**2 + assert expr.m1 is ma + assert expr.m2 is mb + assert expr.is_commutative is True + args_str = list(inspect.signature(expr.__new__).parameters) + assert args_str == ["s", "m1", "m2", "args", "evaluate", "kwargs"] + latex = sp.latex(expr) + assert latex == R"q\left(m_{0}^{2}\right)" - expr2 = MyExpr(4, 3) - assert expr2.doit() == 4**3 + q_value = BreakupMomentum(1, m1=0.2, m2=0.4) + assert isinstance(q_value.s, sp.Integer) + assert isinstance(q_value.m1, sp.Float) + assert isinstance(q_value.m2, sp.Float)