From ee6e9eae7c7530a97fd9a8083dfcc716ee66bd66 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 18 Dec 2023 13:50:14 +0100 Subject: [PATCH] FIX: sympify `unevaluated_expression` instance attributes (#374) * ENH: switch to tuples and improve signatures * ENH: test sympification of `@unevaluated_expression()` attributes * MAINT: remove redundant docstring in tests * MAINT: sort test functions --- src/ampform/sympy/_decorator.py | 18 +++++++++------- tests/sympy/test_decorator.py | 37 ++++++++++++++++++++------------- 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index e1f68400b..1a03edb9d 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -3,7 +3,7 @@ import functools import inspect import sys -from typing import TYPE_CHECKING, Callable, Iterable, TypeVar, overload +from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, overload import sympy as sp @@ -136,7 +136,7 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: attr_values, kwargs = _get_attribute_values(attr_names, *args, **kwargs) attr_values = sp.sympify(attr_values) expr = sp.Expr.__new__(cls, *attr_values, **kwargs) - for name, value in zip(attr_names, args): + for name, value in zip(attr_names, attr_values): setattr(expr, name, value) if evaluate: return expr.evaluate() @@ -146,7 +146,9 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: return cls -def _get_attribute_values(attr_names: list, *args, **kwargs) -> tuple[tuple, dict]: +def _get_attribute_values( + attr_names: tuple[str, ...], *args, **kwargs +) -> tuple[tuple, dict[str, Any]]: if len(args) == len(attr_names): return args, kwargs if len(args) > len(attr_names): @@ -156,7 +158,7 @@ def _get_attribute_values(attr_names: list, *args, **kwargs) -> tuple[tuple, dic ) raise ValueError(msg) attr_values = list(args) - remaining_attr_names = attr_names[len(args) :] + remaining_attr_names = list(attr_names[len(args) :]) for name in list(remaining_attr_names): if name in kwargs: attr_values.append(kwargs.pop(name)) @@ -247,7 +249,7 @@ def wrapper(*args, **kwargs): return decorator -def _get_attribute_names(cls: type) -> list[str]: +def _get_attribute_names(cls: type) -> tuple[str, ...]: """Get the public attributes of a class with dataclass-like semantics. >>> class MyClass: @@ -258,9 +260,11 @@ def _get_attribute_names(cls: type) -> list[str]: ... def print(self): ... ... >>> _get_attribute_names(MyClass) - ['a', 'b'] + ('a', 'b') """ - return [v for v in cls.__annotations__ if not callable(v) if not v.startswith("_")] + return tuple( + k for k in cls.__annotations__ if not callable(k) if not k.startswith("_") + ) @dataclass_transform() diff --git a/tests/sympy/test_decorator.py b/tests/sympy/test_decorator.py index d1228374a..40bf1d50a 100644 --- a/tests/sympy/test_decorator.py +++ b/tests/sympy/test_decorator.py @@ -74,14 +74,22 @@ def my_func(*args, **kwargs) -> int: assert signature.return_annotation == "int" +def test_set_assumptions(): + @_set_assumptions(commutative=True, negative=False, real=True) + class MySqrt: ... + + expr = MySqrt() + assert expr.is_commutative # type: ignore[attr-defined] + assert not expr.is_negative # type: ignore[attr-defined] + assert expr.is_real # type: ignore[attr-defined] + + def test_unevaluated_expression(): @unevaluated_expression class BreakupMomentum(sp.Expr): - r"""Breakup momentum of a two-body decay :math:`a \to 1+2`.""" - - s: sp.Basic - m1: sp.Basic - m2: sp.Basic + s: Any + m1: Any + m2: Any _latex_repr_ = R"q\left({s}\right)" def evaluate(self) -> sp.Expr: @@ -90,11 +98,20 @@ def evaluate(self) -> sp.Expr: m0, ma, mb = sp.symbols("m0 m_a m_b") expr = BreakupMomentum(m0**2, ma, mb) + assert expr.s is m0**2 + assert expr.m1 is ma + assert expr.m2 is mb + assert expr.is_commutative is True args_str = list(inspect.signature(expr.__new__).parameters) assert args_str == ["s", "m1", "m2", "args", "evaluate", "kwargs"] latex = sp.latex(expr) assert latex == R"q\left(m_{0}^{2}\right)" + q_value = BreakupMomentum(1, m1=0.2, m2=0.4) + assert isinstance(q_value.s, sp.Integer) + assert isinstance(q_value.m1, sp.Float) + assert isinstance(q_value.m2, sp.Float) + def test_unevaluated_expression_callable(): @unevaluated_expression(implement_doit=False) @@ -115,13 +132,3 @@ class MySqrt(sp.Expr): expr = MySqrt(-1) assert expr.is_commutative assert expr.is_complex # type: ignore[attr-defined] - - -def test_set_assumptions(): - @_set_assumptions(commutative=True, negative=False, real=True) - class MySqrt: ... - - expr = MySqrt() - assert expr.is_commutative # type: ignore[attr-defined] - assert not expr.is_negative # type: ignore[attr-defined] - assert expr.is_real # type: ignore[attr-defined]