Skip to content

Commit

Permalink
FIX: sympify unevaluated_expression instance attributes (#374)
Browse files Browse the repository at this point in the history
* ENH: switch to tuples and improve signatures
* ENH: test sympification of `@unevaluated_expression()` attributes
* MAINT: remove redundant docstring in tests
* MAINT: sort test functions
  • Loading branch information
redeboer committed Dec 22, 2023
1 parent ce5d6ae commit ee6e9ea
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 22 deletions.
18 changes: 11 additions & 7 deletions src/ampform/sympy/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import functools
import inspect
import sys
from typing import TYPE_CHECKING, Callable, Iterable, TypeVar, overload
from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, overload

import sympy as sp

Expand Down Expand Up @@ -136,7 +136,7 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]:
attr_values, kwargs = _get_attribute_values(attr_names, *args, **kwargs)
attr_values = sp.sympify(attr_values)
expr = sp.Expr.__new__(cls, *attr_values, **kwargs)
for name, value in zip(attr_names, args):
for name, value in zip(attr_names, attr_values):
setattr(expr, name, value)
if evaluate:
return expr.evaluate()
Expand All @@ -146,7 +146,9 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]:
return cls


def _get_attribute_values(attr_names: list, *args, **kwargs) -> tuple[tuple, dict]:
def _get_attribute_values(
attr_names: tuple[str, ...], *args, **kwargs
) -> tuple[tuple, dict[str, Any]]:
if len(args) == len(attr_names):
return args, kwargs
if len(args) > len(attr_names):
Expand All @@ -156,7 +158,7 @@ def _get_attribute_values(attr_names: list, *args, **kwargs) -> tuple[tuple, dic
)
raise ValueError(msg)
attr_values = list(args)
remaining_attr_names = attr_names[len(args) :]
remaining_attr_names = list(attr_names[len(args) :])
for name in list(remaining_attr_names):
if name in kwargs:
attr_values.append(kwargs.pop(name))
Expand Down Expand Up @@ -247,7 +249,7 @@ def wrapper(*args, **kwargs):
return decorator


def _get_attribute_names(cls: type) -> list[str]:
def _get_attribute_names(cls: type) -> tuple[str, ...]:
"""Get the public attributes of a class with dataclass-like semantics.
>>> class MyClass:
Expand All @@ -258,9 +260,11 @@ def _get_attribute_names(cls: type) -> list[str]:
... def print(self): ...
...
>>> _get_attribute_names(MyClass)
['a', 'b']
('a', 'b')
"""
return [v for v in cls.__annotations__ if not callable(v) if not v.startswith("_")]
return tuple(
k for k in cls.__annotations__ if not callable(k) if not k.startswith("_")
)


@dataclass_transform()
Expand Down
37 changes: 22 additions & 15 deletions tests/sympy/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,22 @@ def my_func(*args, **kwargs) -> int:
assert signature.return_annotation == "int"


def test_set_assumptions():
@_set_assumptions(commutative=True, negative=False, real=True)
class MySqrt: ...

expr = MySqrt()
assert expr.is_commutative # type: ignore[attr-defined]
assert not expr.is_negative # type: ignore[attr-defined]
assert expr.is_real # type: ignore[attr-defined]


def test_unevaluated_expression():
@unevaluated_expression
class BreakupMomentum(sp.Expr):
r"""Breakup momentum of a two-body decay :math:`a \to 1+2`."""

s: sp.Basic
m1: sp.Basic
m2: sp.Basic
s: Any
m1: Any
m2: Any
_latex_repr_ = R"q\left({s}\right)"

def evaluate(self) -> sp.Expr:
Expand All @@ -90,11 +98,20 @@ def evaluate(self) -> sp.Expr:

m0, ma, mb = sp.symbols("m0 m_a m_b")
expr = BreakupMomentum(m0**2, ma, mb)
assert expr.s is m0**2
assert expr.m1 is ma
assert expr.m2 is mb
assert expr.is_commutative is True
args_str = list(inspect.signature(expr.__new__).parameters)
assert args_str == ["s", "m1", "m2", "args", "evaluate", "kwargs"]
latex = sp.latex(expr)
assert latex == R"q\left(m_{0}^{2}\right)"

q_value = BreakupMomentum(1, m1=0.2, m2=0.4)
assert isinstance(q_value.s, sp.Integer)
assert isinstance(q_value.m1, sp.Float)
assert isinstance(q_value.m2, sp.Float)


def test_unevaluated_expression_callable():
@unevaluated_expression(implement_doit=False)
Expand All @@ -115,13 +132,3 @@ class MySqrt(sp.Expr):
expr = MySqrt(-1)
assert expr.is_commutative
assert expr.is_complex # type: ignore[attr-defined]


def test_set_assumptions():
@_set_assumptions(commutative=True, negative=False, real=True)
class MySqrt: ...

expr = MySqrt()
assert expr.is_commutative # type: ignore[attr-defined]
assert not expr.is_negative # type: ignore[attr-defined]
assert expr.is_real # type: ignore[attr-defined]

0 comments on commit ee6e9ea

Please sign in to comment.