Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: organize unevaluated_expression() test functions #377

Merged
merged 6 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/sympy/decorator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Required to set mypy options for the tests folder."""
82 changes: 82 additions & 0 deletions tests/sympy/decorator/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from __future__ import annotations

import inspect

import pytest
import sympy as sp

from ampform.sympy._decorator import (
_check_has_implementation,
_implement_latex_repr,
_implement_new_method,
_insert_args_in_signature,
_set_assumptions,
)


def test_check_has_implementation():
with pytest.raises(ValueError, match=r"must have an evaluate\(\) method"):

@_check_has_implementation
class MyExpr1: # pyright: ignore[reportUnusedClass]
pass

with pytest.raises(TypeError, match=r"evaluate\(\) must be a callable method"):

@_check_has_implementation
class MyExpr2: # pyright: ignore[reportUnusedClass]
evaluate = "test"


def test_implement_latex_repr():
@_implement_latex_repr
@_implement_new_method
class MyExpr(sp.Expr):
a: sp.Symbol
b: sp.Symbol
_latex_repr_ = R"f\left({a}, {b}\right)"

alpha, phi = sp.symbols("alpha phi")
expr = MyExpr(alpha, sp.cos(phi))
assert sp.latex(expr) == R"f\left(\alpha, \cos{\left(\phi \right)}\right)"


def test_implement_new_method():
@_implement_new_method
class MyExpr(sp.Expr):
a: int
b: int
c: int

with pytest.raises(
ValueError, match=r"^Expecting 3 positional arguments \(a, b, c\), but got 4$"
):
MyExpr(1, 2, 3, 4) # type: ignore[call-arg]
with pytest.raises(ValueError, match=r"^Missing constructor arguments: c$"):
MyExpr(1, 2) # type: ignore[call-arg]
expr = MyExpr(1, 2, 3)
assert expr.args == (1, 2, 3)
expr = MyExpr(1, b=2, c=3)
assert expr.args == (1, 2, 3)


def test_insert_args_in_signature():
parameters = ["a", "b"]

@_insert_args_in_signature(parameters)
def my_func(*args, **kwargs) -> int:
return 1

signature = inspect.signature(my_func)
assert list(signature.parameters) == [*parameters, "args", "kwargs"]
assert signature.return_annotation == "int"


def test_set_assumptions():
@_set_assumptions(commutative=True, negative=False, real=True)
class MySqrt: ...

expr = MySqrt()
assert expr.is_commutative # type: ignore[attr-defined]
assert not expr.is_negative # type: ignore[attr-defined]
assert expr.is_real # type: ignore[attr-defined]
Original file line number Diff line number Diff line change
Expand Up @@ -3,117 +3,12 @@
import inspect
from typing import Any, ClassVar

import pytest
import sympy as sp

from ampform.sympy._decorator import (
_check_has_implementation,
_implement_latex_repr,
_implement_new_method,
_insert_args_in_signature,
_set_assumptions,
unevaluated_expression,
)
from ampform.sympy._decorator import unevaluated_expression


def test_check_implementation():
with pytest.raises(ValueError, match=r"must have an evaluate\(\) method"):

@_check_has_implementation
class MyExpr1: # pyright: ignore[reportUnusedClass]
pass

with pytest.raises(TypeError, match=r"evaluate\(\) must be a callable method"):

@_check_has_implementation
class MyExpr2: # pyright: ignore[reportUnusedClass]
evaluate = "test"


def test_implement_latex_repr():
@_implement_latex_repr
@_implement_new_method
class MyExpr(sp.Expr):
a: sp.Symbol
b: sp.Symbol
_latex_repr_ = R"f\left({a}, {b}\right)"

alpha, phi = sp.symbols("alpha phi")
expr = MyExpr(alpha, sp.cos(phi))
assert sp.latex(expr) == R"f\left(\alpha, \cos{\left(\phi \right)}\right)"


def test_implement_new_method():
@_implement_new_method
class MyExpr(sp.Expr):
a: int
b: int
c: int

with pytest.raises(
ValueError, match=r"^Expecting 3 positional arguments \(a, b, c\), but got 4$"
):
MyExpr(1, 2, 3, 4) # type: ignore[call-arg]
with pytest.raises(ValueError, match=r"^Missing constructor arguments: c$"):
MyExpr(1, 2) # type: ignore[call-arg]
expr = MyExpr(1, 2, 3)
assert expr.args == (1, 2, 3)
expr = MyExpr(1, b=2, c=3)
assert expr.args == (1, 2, 3)


def test_insert_args_in_signature():
parameters = ["a", "b"]

@_insert_args_in_signature(parameters)
def my_func(*args, **kwargs) -> int:
return 1

signature = inspect.signature(my_func)
assert list(signature.parameters) == [*parameters, "args", "kwargs"]
assert signature.return_annotation == "int"


def test_set_assumptions():
@_set_assumptions(commutative=True, negative=False, real=True)
class MySqrt: ...

expr = MySqrt()
assert expr.is_commutative # type: ignore[attr-defined]
assert not expr.is_negative # type: ignore[attr-defined]
assert expr.is_real # type: ignore[attr-defined]


def test_unevaluated_expression():
@unevaluated_expression
class BreakupMomentum(sp.Expr):
s: Any
m1: Any
m2: Any
_latex_repr_ = R"q\left({s}\right)"

def evaluate(self) -> sp.Expr:
s, m1, m2 = self.args
return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2)) # type: ignore[operator]

m0, ma, mb = sp.symbols("m0 m_a m_b")
expr = BreakupMomentum(m0**2, ma, mb)
assert expr.s is m0**2
assert expr.m1 is ma
assert expr.m2 is mb
assert expr.is_commutative is True
args_str = list(inspect.signature(expr.__new__).parameters)
assert args_str == ["s", "m1", "m2", "args", "evaluate", "kwargs"]
latex = sp.latex(expr)
assert latex == R"q\left(m_{0}^{2}\right)"

q_value = BreakupMomentum(1, m1=0.2, m2=0.4)
assert isinstance(q_value.s, sp.Integer)
assert isinstance(q_value.m1, sp.Float)
assert isinstance(q_value.m2, sp.Float)


def test_unevaluated_expression_classvar():
def test_classvar_behavior():
@unevaluated_expression
class MyExpr(sp.Expr):
x: float
Expand All @@ -134,7 +29,24 @@ def evaluate(self) -> sp.Expr:
assert y_expr.doit() == 5**3


def test_unevaluated_expression_default_argument():
def test_default_argument():
@unevaluated_expression
class MyExpr(sp.Expr):
x: Any
m: int = 2

def evaluate(self) -> sp.Expr:
return self.x**self.m

expr1 = MyExpr(x=5)
assert str(expr1) == "MyExpr(5, 2)"
assert expr1.doit() == 5**2

expr2 = MyExpr(4, 3)
assert expr2.doit() == 4**3


def test_default_argument_with_classvar():
@unevaluated_expression
class FunkyPower(sp.Expr):
x: Any
Expand Down Expand Up @@ -171,7 +83,7 @@ def evaluate(self) -> sp.Expr:
assert expr.default_return is half


def test_unevaluated_expression_callable():
def test_no_implement_doit():
@unevaluated_expression(implement_doit=False)
class Squared(sp.Expr):
x: Any
Expand All @@ -192,18 +104,30 @@ class MySqrt(sp.Expr):
assert expr.is_complex # type: ignore[attr-defined]


def test_unevaluated_expression_default_args():
def test_symbols_and_no_symbols():
@unevaluated_expression
class MyExpr(sp.Expr):
x: Any
m: int = 2
class BreakupMomentum(sp.Expr):
s: Any
m1: Any
m2: Any
_latex_repr_ = R"q\left({s}\right)"

def evaluate(self) -> sp.Expr:
return self.x**self.m
s, m1, m2 = self.args
return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2)) # type: ignore[operator]

expr1 = MyExpr(x=5)
assert str(expr1) == "MyExpr(5, 2)"
assert expr1.doit() == 5**2
m0, ma, mb = sp.symbols("m0 m_a m_b")
expr = BreakupMomentum(m0**2, ma, mb)
assert expr.s is m0**2
assert expr.m1 is ma
assert expr.m2 is mb
assert expr.is_commutative is True
args_str = list(inspect.signature(expr.__new__).parameters)
assert args_str == ["s", "m1", "m2", "args", "evaluate", "kwargs"]
latex = sp.latex(expr)
assert latex == R"q\left(m_{0}^{2}\right)"

expr2 = MyExpr(4, 3)
assert expr2.doit() == 4**3
q_value = BreakupMomentum(1, m1=0.2, m2=0.4)
assert isinstance(q_value.s, sp.Integer)
assert isinstance(q_value.m1, sp.Float)
assert isinstance(q_value.m2, sp.Float)
Loading