From c4452dd78468601ec86a2cddb0b4e6f6f587db1c Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 19 Dec 2023 10:52:11 +0100 Subject: [PATCH 01/13] DOC: add docstring to `_get_attribute_values()` --- src/ampform/sympy/_decorator.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index a3e42ce2c..1251d561c 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -149,6 +149,17 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: def _get_attribute_values( cls: type[ExprClass], attr_names: tuple[str, ...], *args, **kwargs ) -> tuple[tuple, dict[str, Any]]: + """Extract the attribute values from the constructor arguments. + + Returns a `tuple` of: + + 1. the extracted, ordered attributes as requested by :code:`attr_names`, + 2. a `dict` of remaining keyword arguments that can be used hints for the + constructed :class:`sp.Expr` instance. + + An attempt is made to get any missing attributes from the type hints in the class + definition. + """ if len(args) == len(attr_names): return args, kwargs if len(args) > len(attr_names): From e26611a048c3ae6fa4e763aeeb679ac3226b1adf Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 19 Dec 2023 11:33:13 +0100 Subject: [PATCH 02/13] MAINT: write test with unsympifiable class --- .cspell.json | 1 + .../decorator/test_unevaluated_expression.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/.cspell.json b/.cspell.json index 871b55b36..862df7d95 100644 --- a/.cspell.json +++ b/.cspell.json @@ -182,6 +182,7 @@ "sharey", "startswith", "suptitle", + "sympifiable", "sympified", "sympify", "symplot", diff --git a/tests/sympy/decorator/test_unevaluated_expression.py b/tests/sympy/decorator/test_unevaluated_expression.py index f04de40bf..3cbca27fb 100644 --- a/tests/sympy/decorator/test_unevaluated_expression.py +++ b/tests/sympy/decorator/test_unevaluated_expression.py @@ -29,6 +29,24 @@ def evaluate(self) -> sp.Expr: assert y_expr.doit() == 5**3 +def test_construction_non_sympy_attributes(): + class CannotBeSympified: ... + + @unevaluated_expression(implement_doit=False) + class MyExpr(sp.Expr): + sympifiable: Any + non_sympy: CannotBeSympified + + obj = CannotBeSympified() + expr = MyExpr( + sympifiable=3, + non_sympy=obj, + ) + assert expr.sympifiable is not 3 # noqa: F632 + assert expr.sympifiable is sp.Integer(3) + assert expr.non_sympy is obj + + def test_default_argument(): @unevaluated_expression class MyExpr(sp.Expr): From 284a9c735f44e68796de08a3fcf99054113336ee Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 19 Dec 2023 10:57:10 +0100 Subject: [PATCH 03/13] ENH: support non-sympy arguments in `unevaluated_expression()` --- src/ampform/sympy/_decorator.py | 36 +++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index 1251d561c..80a3ab03d 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -6,6 +6,8 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, overload import sympy as sp +from attrs import frozen +from sympy.utilities.exceptions import SymPyDeprecationWarning if sys.version_info < (3, 8): from typing_extensions import Protocol, TypedDict @@ -133,10 +135,10 @@ def _implement_new_method(cls: type[ExprClass]) -> type[ExprClass]: @functools.wraps(cls.__new__) @_insert_args_in_signature(attr_names, idx=1) def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: - positional_args, hints = _get_attribute_values(cls, attr_names, *args, **kwargs) - sympified_args = sp.sympify(positional_args) - expr = sp.Expr.__new__(cls, *sympified_args, **hints) - for name, value in zip(attr_names, sympified_args): + attr_values, hints = _get_attribute_values(cls, attr_names, *args, **kwargs) + converted_attr_values = _safe_sympify(*attr_values) + expr = sp.Expr.__new__(cls, *converted_attr_values.sympy, **hints) + for name, value in zip(attr_names, converted_attr_values.all_args): setattr(expr, name, value) if evaluate: return expr.evaluate() @@ -184,6 +186,32 @@ def _get_attribute_values( return tuple(attr_values), kwargs +def _safe_sympify(*args: Any) -> _ExprNewArumgents: + all_args = [] + sympy_args = [] + non_sympy_args = [] + for arg in args: + try: + converted_arg = sp.sympify(arg) + sympy_args.append(converted_arg) + except (TypeError, SymPyDeprecationWarning, sp.SympifyError): + converted_arg = arg + non_sympy_args.append(converted_arg) + all_args.append(converted_arg) + return _ExprNewArumgents( + all_args=tuple(all_args), + sympy=tuple(sympy_args), + non_sympy=tuple(non_sympy_args), + ) + + +@frozen +class _ExprNewArumgents: + all_args: tuple[Any, ...] + sympy: tuple[sp.Basic, ...] + non_sympy: tuple[Any, ...] + + class LatexMethod(Protocol): def __call__(self, printer: LatexPrinter, *args) -> str: ... From e38267c3e8648d3db437ce39c1ae899bd7ba224e Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 19 Dec 2023 12:09:15 +0100 Subject: [PATCH 04/13] ENH: implement hash for non-sympy attributes --- src/ampform/sympy/_decorator.py | 43 ++++++++++++++++--- .../decorator/test_unevaluated_expression.py | 13 ++++++ 2 files changed, 49 insertions(+), 7 deletions(-) diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index 80a3ab03d..5aaee9693 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -3,7 +3,8 @@ import functools import inspect import sys -from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, overload +from inspect import isclass +from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterable, TypeVar, overload import sympy as sp from attrs import frozen @@ -23,8 +24,9 @@ from sympy.printing.latex import LatexPrinter ExprClass = TypeVar("ExprClass", bound=sp.Expr) -_P = ParamSpec("_P") -_T = TypeVar("_T") +P = ParamSpec("P") +T = TypeVar("T") +H = TypeVar("H", bound=Hashable) class SymPyAssumptions(TypedDict, total=False): @@ -140,14 +142,41 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: expr = sp.Expr.__new__(cls, *converted_attr_values.sympy, **hints) for name, value in zip(attr_names, converted_attr_values.all_args): setattr(expr, name, value) + expr._non_sympy_args = converted_attr_values.non_sympy if evaluate: return expr.evaluate() return expr + def _hashable_content(self) -> tuple: + hashable_content: tuple[sp.Basic, ...] = super( + sp.Expr, self + )._hashable_content() + if not self._non_sympy_args: + return hashable_content + remaining_content = (_get_hashable_object(arg) for arg in self._non_sympy_args) + return (*hashable_content, *remaining_content) + cls.__new__ = new_method # type: ignore[method-assign] + cls._hashable_content = _hashable_content # type: ignore[method-assign] return cls +@overload +def _get_hashable_object(obj: type) -> str: ... # type: ignore[overload-overlap] +@overload +def _get_hashable_object(obj: H) -> H: ... +@overload +def _get_hashable_object(obj: Any) -> str: ... +def _get_hashable_object(obj): + if isclass(obj): + return str(obj) + try: + hash(obj) + except TypeError: + return str(obj) + return obj + + def _get_attribute_values( cls: type[ExprClass], attr_names: tuple[str, ...], *args, **kwargs ) -> tuple[tuple, dict[str, Any]]: @@ -217,7 +246,7 @@ def __call__(self, printer: LatexPrinter, *args) -> str: ... @dataclass_transform() -def _implement_latex_repr(cls: type[_T]) -> type[_T]: +def _implement_latex_repr(cls: type[T]) -> type[T]: _latex_repr_: LatexMethod | str | None = getattr(cls, "_latex_repr_", None) if _latex_repr_ is None: msg = ( @@ -267,7 +296,7 @@ def _check_has_implementation(cls: type) -> None: def _insert_args_in_signature( new_params: Iterable[str] | None = None, idx: int = 0 -) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: +) -> Callable[[Callable[P, T]], Callable[P, T]]: if new_params is None: new_params = [] @@ -318,8 +347,8 @@ def _get_attribute_names(cls: type) -> tuple[str, ...]: @dataclass_transform() def _set_assumptions( **assumptions: Unpack[SymPyAssumptions], -) -> Callable[[type[_T]], type[_T]]: - def class_wrapper(cls: _T) -> _T: +) -> Callable[[type[T]], type[T]]: + def class_wrapper(cls: T) -> T: for assumption, value in assumptions.items(): setattr(cls, f"is_{assumption}", value) return cls diff --git a/tests/sympy/decorator/test_unevaluated_expression.py b/tests/sympy/decorator/test_unevaluated_expression.py index 3cbca27fb..dde856f18 100644 --- a/tests/sympy/decorator/test_unevaluated_expression.py +++ b/tests/sympy/decorator/test_unevaluated_expression.py @@ -101,6 +101,19 @@ def evaluate(self) -> sp.Expr: assert expr.default_return is half +def test_hashable_with_classes(): + class CannotBeSympified: ... + + @unevaluated_expression(implement_doit=False) + class MyExpr(sp.Expr): + x: Any + typ: type[CannotBeSympified] + + x = sp.Symbol("x") + expr = MyExpr(x, typ=CannotBeSympified) + assert expr._hashable_content() == (x, str(CannotBeSympified)) + + def test_no_implement_doit(): @unevaluated_expression(implement_doit=False) class Squared(sp.Expr): From 87f4f6e0d76fcbcd8f853ad56d17a090a2e3512b Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 19 Dec 2023 13:53:25 +0100 Subject: [PATCH 05/13] ENH: implement `xreplace()` method for non-sympy attributes --- src/ampform/sympy/_decorator.py | 31 ++++++++++++++++--- .../decorator/test_unevaluated_expression.py | 21 +++++++++++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index 5aaee9693..a1d5a19af 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -3,6 +3,7 @@ import functools import inspect import sys +from collections import abc from inspect import isclass from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterable, TypeVar, overload @@ -115,7 +116,7 @@ def decorator(cls: type[ExprClass]) -> type[ExprClass]: @dataclass_transform() -def _implement_new_method(cls: type[ExprClass]) -> type[ExprClass]: +def _implement_new_method(cls: type[ExprClass]) -> type[ExprClass]: # noqa: C901 """Implement the :meth:`__new__` method for dataclass-like SymPy expression classes. >>> @_implement_new_method @@ -142,22 +143,44 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: expr = sp.Expr.__new__(cls, *converted_attr_values.sympy, **hints) for name, value in zip(attr_names, converted_attr_values.all_args): setattr(expr, name, value) + expr._all_args = converted_attr_values.all_args expr._non_sympy_args = converted_attr_values.non_sympy if evaluate: return expr.evaluate() return expr def _hashable_content(self) -> tuple: - hashable_content: tuple[sp.Basic, ...] = super( - sp.Expr, self - )._hashable_content() + hashable_content = super(sp.Expr, self)._hashable_content() if not self._non_sympy_args: return hashable_content remaining_content = (_get_hashable_object(arg) for arg in self._non_sympy_args) return (*hashable_content, *remaining_content) + def _xreplace(self, rule) -> tuple[sp.Expr, bool]: + # https://github.com/sympy/sympy/blob/1.12/sympy/core/basic.py#L1233-L1253 + if self in rule: + return rule[self], True + if rule: + new_args = [] + hit = False + for arg in self._all_args: + if hasattr(arg, "_xreplace") and not isclass(arg): + replace_result, is_replaced = arg._xreplace(rule) + elif isinstance(rule, abc.Mapping): + is_replaced = bool(arg in rule) + replace_result = rule.get(arg, arg) + else: + replace_result = arg + is_replaced = False + new_args.append(replace_result) + hit |= is_replaced + if hit: + return self.func(*new_args), True + return self, False + cls.__new__ = new_method # type: ignore[method-assign] cls._hashable_content = _hashable_content # type: ignore[method-assign] + cls._xreplace = _xreplace # type: ignore[method-assign] return cls diff --git a/tests/sympy/decorator/test_unevaluated_expression.py b/tests/sympy/decorator/test_unevaluated_expression.py index dde856f18..1ab85b8fe 100644 --- a/tests/sympy/decorator/test_unevaluated_expression.py +++ b/tests/sympy/decorator/test_unevaluated_expression.py @@ -162,3 +162,24 @@ def evaluate(self) -> sp.Expr: assert isinstance(q_value.s, sp.Integer) assert isinstance(q_value.m1, sp.Float) assert isinstance(q_value.m2, sp.Float) + + +def test_xreplace_with_non_sympy_attributes(): + class Protocol: ... + + class Protocol1(Protocol): ... + + class Protocol2(Protocol): ... + + @unevaluated_expression(implement_doit=False) + class MyExpr(sp.Expr): + x: Any + protocol: type[Protocol] = Protocol1 + + x, y = sp.symbols("x y") + expr = MyExpr(x) + replaced_expr: MyExpr = expr.xreplace({x: y, Protocol1: Protocol2}) + assert replaced_expr.x is not x + assert replaced_expr.x is y + assert replaced_expr.protocol is not Protocol1 + assert replaced_expr.protocol is Protocol2 From d5fd54bd667507c3de863c670cd9d1cb3d164e2b Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 19 Dec 2023 14:11:21 +0100 Subject: [PATCH 06/13] BEHAVIOR: do not sympify `str` attributes --- src/ampform/sympy/_decorator.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index a1d5a19af..f562f94ab 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -243,11 +243,10 @@ def _safe_sympify(*args: Any) -> _ExprNewArumgents: sympy_args = [] non_sympy_args = [] for arg in args: - try: - converted_arg = sp.sympify(arg) + converted_arg, is_sympy = _try_sympify(arg) + if is_sympy: sympy_args.append(converted_arg) - except (TypeError, SymPyDeprecationWarning, sp.SympifyError): - converted_arg = arg + else: non_sympy_args.append(converted_arg) all_args.append(converted_arg) return _ExprNewArumgents( @@ -257,6 +256,15 @@ def _safe_sympify(*args: Any) -> _ExprNewArumgents: ) +def _try_sympify(obj) -> tuple[Any, bool]: + if isinstance(obj, str): + return obj, False + try: + return sp.sympify(obj), True + except (TypeError, SymPyDeprecationWarning, sp.SympifyError): + return obj, False + + @frozen class _ExprNewArumgents: all_args: tuple[Any, ...] From 677795d5314207917ae9dad9b78eb955aaa270e5 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 19 Dec 2023 14:36:29 +0100 Subject: [PATCH 07/13] MAINT: move method implementations to module level --- src/ampform/sympy/_decorator.py | 66 +++++++++++++++++---------------- 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index f562f94ab..c8c1a4fcb 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -116,7 +116,7 @@ def decorator(cls: type[ExprClass]) -> type[ExprClass]: @dataclass_transform() -def _implement_new_method(cls: type[ExprClass]) -> type[ExprClass]: # noqa: C901 +def _implement_new_method(cls: type[ExprClass]) -> type[ExprClass]: """Implement the :meth:`__new__` method for dataclass-like SymPy expression classes. >>> @_implement_new_method @@ -149,38 +149,9 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: return expr.evaluate() return expr - def _hashable_content(self) -> tuple: - hashable_content = super(sp.Expr, self)._hashable_content() - if not self._non_sympy_args: - return hashable_content - remaining_content = (_get_hashable_object(arg) for arg in self._non_sympy_args) - return (*hashable_content, *remaining_content) - - def _xreplace(self, rule) -> tuple[sp.Expr, bool]: - # https://github.com/sympy/sympy/blob/1.12/sympy/core/basic.py#L1233-L1253 - if self in rule: - return rule[self], True - if rule: - new_args = [] - hit = False - for arg in self._all_args: - if hasattr(arg, "_xreplace") and not isclass(arg): - replace_result, is_replaced = arg._xreplace(rule) - elif isinstance(rule, abc.Mapping): - is_replaced = bool(arg in rule) - replace_result = rule.get(arg, arg) - else: - replace_result = arg - is_replaced = False - new_args.append(replace_result) - hit |= is_replaced - if hit: - return self.func(*new_args), True - return self, False - cls.__new__ = new_method # type: ignore[method-assign] - cls._hashable_content = _hashable_content # type: ignore[method-assign] - cls._xreplace = _xreplace # type: ignore[method-assign] + cls._hashable_content = _hashable_content_method # type: ignore[method-assign] + cls._xreplace = _xreplace_method # type: ignore[method-assign] return cls @@ -385,3 +356,34 @@ def class_wrapper(cls: T) -> T: return cls return class_wrapper + + +def _hashable_content_method(self) -> tuple: + hashable_content = super(sp.Expr, self)._hashable_content() + if not self._non_sympy_args: + return hashable_content + remaining_content = (_get_hashable_object(arg) for arg in self._non_sympy_args) + return (*hashable_content, *remaining_content) + + +def _xreplace_method(self, rule) -> tuple[sp.Expr, bool]: + # https://github.com/sympy/sympy/blob/1.12/sympy/core/basic.py#L1233-L1253 + if self in rule: + return rule[self], True + if rule: + new_args = [] + hit = False + for arg in self._all_args: + if hasattr(arg, "_xreplace") and not isclass(arg): + replace_result, is_replaced = arg._xreplace(rule) + elif isinstance(rule, abc.Mapping): + is_replaced = bool(arg in rule) + replace_result = rule.get(arg, arg) + else: + replace_result = arg + is_replaced = False + new_args.append(replace_result) + hit |= is_replaced + if hit: + return self.func(*new_args), True + return self, False From 022b0f04d79ed7fc3b9a66eab76640bb916876ae Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 19 Dec 2023 14:49:15 +0100 Subject: [PATCH 08/13] ENH: implement `subs()` method for `unevaluated_expression` classes --- src/ampform/sympy/_decorator.py | 34 +++++++++++++++++++ .../decorator/test_unevaluated_expression.py | 18 +++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index c8c1a4fcb..8fd27398b 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -9,6 +9,7 @@ import sympy as sp from attrs import frozen +from sympy.core.basic import _aresame from sympy.utilities.exceptions import SymPyDeprecationWarning if sys.version_info < (3, 8): @@ -150,6 +151,7 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: return expr cls.__new__ = new_method # type: ignore[method-assign] + cls._eval_subs = _eval_subs_method # type: ignore[method-assign] cls._hashable_content = _hashable_content_method # type: ignore[method-assign] cls._xreplace = _xreplace_method # type: ignore[method-assign] return cls @@ -358,6 +360,38 @@ def class_wrapper(cls: T) -> T: return class_wrapper +def _eval_subs_method(self, old, new, **hints): + # https://github.com/sympy/sympy/blob/1.12/sympy/core/basic.py#L1117-L1147 + hit = False + substituted_attrs = list(self._all_args) + for i, old_attr in enumerate(substituted_attrs): + if not hasattr(old_attr, "_eval_subs"): + continue + if isclass(old_attr): + continue + new_attr = old_attr._subs(old, new, **hints) + if not _aresame(new_attr, old_attr): + hit = True + substituted_attrs[i] = new_attr + if hit: + rv = self.func(*substituted_attrs) + hack2 = hints.get("hack2", False) + if hack2 and self.is_Mul and not rv.is_Mul: # 2-arg hack + coefficient = sp.S.One + nonnumber = [] + for i in substituted_attrs: + if i.is_Number: + coefficient *= i + else: + nonnumber.append(i) + nonnumber = self.func(*nonnumber) + if coefficient is sp.S.One: + return nonnumber + return self.func(coefficient, nonnumber, evaluate=False) + return rv + return self + + def _hashable_content_method(self) -> tuple: hashable_content = super(sp.Expr, self)._hashable_content() if not self._non_sympy_args: diff --git a/tests/sympy/decorator/test_unevaluated_expression.py b/tests/sympy/decorator/test_unevaluated_expression.py index 1ab85b8fe..106a5108e 100644 --- a/tests/sympy/decorator/test_unevaluated_expression.py +++ b/tests/sympy/decorator/test_unevaluated_expression.py @@ -135,7 +135,7 @@ class MySqrt(sp.Expr): assert expr.is_complex # type: ignore[attr-defined] -def test_symbols_and_no_symbols(): +def test_non_symbols_construction(): @unevaluated_expression class BreakupMomentum(sp.Expr): s: Any @@ -164,6 +164,22 @@ def evaluate(self) -> sp.Expr: assert isinstance(q_value.m2, sp.Float) +def test_subs_with_non_sympy_attributes(): + class Protocol: ... + + @unevaluated_expression(implement_doit=False) + class MyExpr(sp.Expr): + x: Any + protocol: type[Protocol] = Protocol + + x, y = sp.symbols("x y") + expr = MyExpr(x) + replaced_expr: MyExpr = expr.subs(x, y) + assert replaced_expr.x is not x + assert replaced_expr.x is y + assert replaced_expr.protocol is Protocol + + def test_xreplace_with_non_sympy_attributes(): class Protocol: ... From 95d69190101f609f59928cf318d9e6445debabcf Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 19 Dec 2023 16:02:08 +0100 Subject: [PATCH 09/13] MAINT: remove redundant type ignore --- src/ampform/sympy/_decorator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index 8fd27398b..fb3b7e96c 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -71,8 +71,8 @@ def unevaluated_expression( ) -> Callable[[type[ExprClass]], type[ExprClass]]: ... -@dataclass_transform() # type: ignore[misc] -def unevaluated_expression( # type: ignore[misc] +@dataclass_transform() +def unevaluated_expression( cls: type[ExprClass] | None = None, *, implement_doit=True, **assumptions ): r"""Decorator for defining 'unevaluated' SymPy expressions. From f072bab32be6ffa128c2dcb091740864f52f40f4 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 19 Dec 2023 16:05:51 +0100 Subject: [PATCH 10/13] BREAK: rename `unevaluated_expression()` to `unevaluated()` The part `_expression` redundant, because the class already derives from `sympy.Expr` --- docs/usage/sympy.ipynb | 10 ++++----- src/ampform/dynamics/__init__.py | 4 ++-- src/ampform/kinematics/phasespace.py | 6 ++--- src/ampform/sympy/__init__.py | 4 ++-- src/ampform/sympy/_decorator.py | 10 ++++----- ...ated_expression.py => test_unevaluated.py} | 22 +++++++++---------- 6 files changed, 27 insertions(+), 29 deletions(-) rename tests/sympy/decorator/{test_unevaluated_expression.py => test_unevaluated.py} (90%) diff --git a/docs/usage/sympy.ipynb b/docs/usage/sympy.ipynb index f93e42676..8be5bfb5c 100644 --- a/docs/usage/sympy.ipynb +++ b/docs/usage/sympy.ipynb @@ -81,7 +81,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The {func}`.unevaluated_expression` decorator makes it easier to write classes that represent a mathematical function definition. It makes a class that derives from {class}`sp.Expr ` behave more like a {func}`~.dataclasses.dataclass` (see [PEP 861](https://peps.python.org/pep-0681)). All you have to do is:\n", + "The {func}`.unevaluated` decorator makes it easier to write classes that represent a mathematical function definition. It makes a class that derives from {class}`sp.Expr ` behave more like a {func}`~.dataclasses.dataclass` (see [PEP 861](https://peps.python.org/pep-0681)). All you have to do is:\n", "\n", "1. Specify the arguments the function requires.\n", "2. Specify how to render the 'unevaluated' or 'folded' form of the expression with a `_latex_repr_` string or method.\n", @@ -98,10 +98,10 @@ "source": [ "import sympy as sp\n", "\n", - "from ampform.sympy import unevaluated_expression\n", + "from ampform.sympy import unevaluated\n", "\n", "\n", - "@unevaluated_expression(real=False)\n", + "@unevaluated(real=False)\n", "class PhspFactorSWave(sp.Expr):\n", " s: sp.Symbol\n", " m1: sp.Symbol\n", @@ -119,7 +119,7 @@ " return 16 * sp.pi * sp.I * cm\n", "\n", "\n", - "@unevaluated_expression(real=False)\n", + "@unevaluated(real=False)\n", "class BreakupMomentum(sp.Expr):\n", " s: sp.Symbol\n", " m1: sp.Symbol\n", @@ -180,7 +180,7 @@ "from typing import Any, ClassVar\n", "\n", "\n", - "@unevaluated_expression\n", + "@unevaluated\n", "class FunkyPower(sp.Expr):\n", " x: Any\n", " m: int = 1\n", diff --git a/src/ampform/dynamics/__init__.py b/src/ampform/dynamics/__init__.py index 84d7422c6..eec76f35f 100644 --- a/src/ampform/dynamics/__init__.py +++ b/src/ampform/dynamics/__init__.py @@ -26,14 +26,14 @@ UnevaluatedExpression, determine_indices, implement_doit_method, - unevaluated_expression, + unevaluated, ) if TYPE_CHECKING: from sympy.printing.latex import LatexPrinter -@unevaluated_expression +@unevaluated class BlattWeisskopfSquared(sp.Expr): # cspell:ignore pychyGekoppeltePartialwellenanalyseAnnihilationen r"""Blatt-Weisskopf function :math:`B_L^2(z)`, up to :math:`L \leq 8`. diff --git a/src/ampform/kinematics/phasespace.py b/src/ampform/kinematics/phasespace.py index 8796a8812..4b98e70c7 100644 --- a/src/ampform/kinematics/phasespace.py +++ b/src/ampform/kinematics/phasespace.py @@ -9,10 +9,10 @@ import sympy as sp -from ampform.sympy import unevaluated_expression +from ampform.sympy import unevaluated -@unevaluated_expression +@unevaluated class Kibble(sp.Expr): """Kibble function for determining the phase space region.""" @@ -34,7 +34,7 @@ def evaluate(self) -> Kallen: ) -@unevaluated_expression +@unevaluated class Kallen(sp.Expr): """Källén function, used for computing break-up momenta.""" diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index 6945e31b7..581448068 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -1,6 +1,6 @@ """Tools that facilitate in building :mod:`sympy` expressions. -.. autodecorator:: unevaluated_expression +.. autodecorator:: unevaluated .. dropdown:: SymPy assumptions .. autodata:: ExprClass @@ -30,7 +30,7 @@ from ._decorator import ( ExprClass, # noqa: F401 # pyright: ignore[reportUnusedImport] SymPyAssumptions, # noqa: F401 # pyright: ignore[reportUnusedImport] - unevaluated_expression, # noqa: F401 # pyright: ignore[reportUnusedImport] + unevaluated, # noqa: F401 # pyright: ignore[reportUnusedImport] ) if TYPE_CHECKING: diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index fb3b7e96c..f23ce09ae 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -62,17 +62,15 @@ class SymPyAssumptions(TypedDict, total=False): @overload -def unevaluated_expression(cls: type[ExprClass]) -> type[ExprClass]: ... +def unevaluated(cls: type[ExprClass]) -> type[ExprClass]: ... @overload -def unevaluated_expression( +def unevaluated( *, implement_doit: bool = True, **assumptions: Unpack[SymPyAssumptions], ) -> Callable[[type[ExprClass]], type[ExprClass]]: ... - - @dataclass_transform() -def unevaluated_expression( +def unevaluated( cls: type[ExprClass] | None = None, *, implement_doit=True, **assumptions ): r"""Decorator for defining 'unevaluated' SymPy expressions. @@ -80,7 +78,7 @@ def unevaluated_expression( Unevaluated expressions are handy for defining large expressions that consist of several sub-definitions. - >>> @unevaluated_expression + >>> @unevaluated ... class MyExpr(sp.Expr): ... x: sp.Symbol ... y: sp.Symbol diff --git a/tests/sympy/decorator/test_unevaluated_expression.py b/tests/sympy/decorator/test_unevaluated.py similarity index 90% rename from tests/sympy/decorator/test_unevaluated_expression.py rename to tests/sympy/decorator/test_unevaluated.py index 106a5108e..aee580495 100644 --- a/tests/sympy/decorator/test_unevaluated_expression.py +++ b/tests/sympy/decorator/test_unevaluated.py @@ -5,11 +5,11 @@ import sympy as sp -from ampform.sympy._decorator import unevaluated_expression +from ampform.sympy._decorator import unevaluated def test_classvar_behavior(): - @unevaluated_expression + @unevaluated class MyExpr(sp.Expr): x: float m: ClassVar[int] = 2 @@ -32,7 +32,7 @@ def evaluate(self) -> sp.Expr: def test_construction_non_sympy_attributes(): class CannotBeSympified: ... - @unevaluated_expression(implement_doit=False) + @unevaluated(implement_doit=False) class MyExpr(sp.Expr): sympifiable: Any non_sympy: CannotBeSympified @@ -48,7 +48,7 @@ class MyExpr(sp.Expr): def test_default_argument(): - @unevaluated_expression + @unevaluated class MyExpr(sp.Expr): x: Any m: int = 2 @@ -65,7 +65,7 @@ def evaluate(self) -> sp.Expr: def test_default_argument_with_classvar(): - @unevaluated_expression + @unevaluated class FunkyPower(sp.Expr): x: Any m: int = 1 @@ -104,7 +104,7 @@ def evaluate(self) -> sp.Expr: def test_hashable_with_classes(): class CannotBeSympified: ... - @unevaluated_expression(implement_doit=False) + @unevaluated(implement_doit=False) class MyExpr(sp.Expr): x: Any typ: type[CannotBeSympified] @@ -115,7 +115,7 @@ class MyExpr(sp.Expr): def test_no_implement_doit(): - @unevaluated_expression(implement_doit=False) + @unevaluated(implement_doit=False) class Squared(sp.Expr): x: Any @@ -126,7 +126,7 @@ def evaluate(self) -> sp.Expr: assert str(sqrt) == "Squared(2)" assert str(sqrt.doit()) == "Squared(2)" - @unevaluated_expression(complex=True, implement_doit=False) + @unevaluated(complex=True, implement_doit=False) class MySqrt(sp.Expr): x: Any @@ -136,7 +136,7 @@ class MySqrt(sp.Expr): def test_non_symbols_construction(): - @unevaluated_expression + @unevaluated class BreakupMomentum(sp.Expr): s: Any m1: Any @@ -167,7 +167,7 @@ def evaluate(self) -> sp.Expr: def test_subs_with_non_sympy_attributes(): class Protocol: ... - @unevaluated_expression(implement_doit=False) + @unevaluated(implement_doit=False) class MyExpr(sp.Expr): x: Any protocol: type[Protocol] = Protocol @@ -187,7 +187,7 @@ class Protocol1(Protocol): ... class Protocol2(Protocol): ... - @unevaluated_expression(implement_doit=False) + @unevaluated(implement_doit=False) class MyExpr(sp.Expr): x: Any protocol: type[Protocol] = Protocol1 From 2ddb323fdb30b979eeaa5c9e21ff858e0d98f2a0 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 19 Dec 2023 16:26:07 +0100 Subject: [PATCH 11/13] MAINT: test class var definition without `ClassVar` --- docs/usage/sympy.ipynb | 3 ++- tests/sympy/decorator/test_unevaluated.py | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/usage/sympy.ipynb b/docs/usage/sympy.ipynb index 8be5bfb5c..80369b35d 100644 --- a/docs/usage/sympy.ipynb +++ b/docs/usage/sympy.ipynb @@ -166,7 +166,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Class variables and default arguments to instance arguments are also supported:" + "Class variables and default arguments to instance arguments are also supported. They can either be indicated with {class}`typing.ClassVar` or by not providing a type hint:" ] }, { @@ -185,6 +185,7 @@ " x: Any\n", " m: int = 1\n", " default_return: ClassVar[sp.Expr | None] = None\n", + " class_name = \"my name\"\n", " _latex_repr_ = R\"f_{{{m}}}\\left({x}\\right)\"\n", "\n", " def evaluate(self) -> sp.Expr | None:\n", diff --git a/tests/sympy/decorator/test_unevaluated.py b/tests/sympy/decorator/test_unevaluated.py index aee580495..ce9bfdba1 100644 --- a/tests/sympy/decorator/test_unevaluated.py +++ b/tests/sympy/decorator/test_unevaluated.py @@ -13,6 +13,7 @@ def test_classvar_behavior(): class MyExpr(sp.Expr): x: float m: ClassVar[int] = 2 + class_name = "MyExpr" def evaluate(self) -> sp.Expr: return self.x**self.m # type: ignore[return-value] @@ -24,9 +25,15 @@ def evaluate(self) -> sp.Expr: y_expr = MyExpr(5) assert x_expr.doit() == 4**2 assert y_expr.doit() == 5**2 + assert x_expr.class_name == "MyExpr" + assert y_expr.class_name == "MyExpr" MyExpr.m = 3 + new_name = "different name" + MyExpr.class_name = new_name assert x_expr.doit() == 4**3 assert y_expr.doit() == 5**3 + assert x_expr.class_name == new_name + assert y_expr.class_name == new_name def test_construction_non_sympy_attributes(): From 2f7532d58f06d042dbe0f4399e93da38d3cb7511 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Wed, 20 Dec 2023 16:45:38 +0100 Subject: [PATCH 12/13] MAINT: put `TypeVar` definitions under `TYPE_CHECKING` --- src/ampform/sympy/_decorator.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index f23ce09ae..138ec3698 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -18,17 +18,23 @@ from typing import Protocol, TypedDict if sys.version_info < (3, 11): - from typing_extensions import ParamSpec, Unpack, dataclass_transform + from typing_extensions import dataclass_transform else: - from typing import ParamSpec, Unpack, dataclass_transform + from typing import dataclass_transform if TYPE_CHECKING: from sympy.printing.latex import LatexPrinter + if sys.version_info < (3, 11): + from typing_extensions import ParamSpec, Unpack + else: + from typing import ParamSpec, Unpack + + H = TypeVar("H", bound=Hashable) + P = ParamSpec("P") + T = TypeVar("T") + ExprClass = TypeVar("ExprClass", bound=sp.Expr) -P = ParamSpec("P") -T = TypeVar("T") -H = TypeVar("H", bound=Hashable) class SymPyAssumptions(TypedDict, total=False): From d603956c2b1951b737403b8a948d8f14914f53eb Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Thu, 21 Dec 2023 16:04:12 +0100 Subject: [PATCH 13/13] Kick CI