diff --git a/.cspell.json b/.cspell.json index 871b55b36..862df7d95 100644 --- a/.cspell.json +++ b/.cspell.json @@ -182,6 +182,7 @@ "sharey", "startswith", "suptitle", + "sympifiable", "sympified", "sympify", "symplot", diff --git a/docs/usage/sympy.ipynb b/docs/usage/sympy.ipynb index f93e42676..80369b35d 100644 --- a/docs/usage/sympy.ipynb +++ b/docs/usage/sympy.ipynb @@ -81,7 +81,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The {func}`.unevaluated_expression` decorator makes it easier to write classes that represent a mathematical function definition. It makes a class that derives from {class}`sp.Expr ` behave more like a {func}`~.dataclasses.dataclass` (see [PEP 861](https://peps.python.org/pep-0681)). All you have to do is:\n", + "The {func}`.unevaluated` decorator makes it easier to write classes that represent a mathematical function definition. It makes a class that derives from {class}`sp.Expr ` behave more like a {func}`~.dataclasses.dataclass` (see [PEP 861](https://peps.python.org/pep-0681)). All you have to do is:\n", "\n", "1. Specify the arguments the function requires.\n", "2. Specify how to render the 'unevaluated' or 'folded' form of the expression with a `_latex_repr_` string or method.\n", @@ -98,10 +98,10 @@ "source": [ "import sympy as sp\n", "\n", - "from ampform.sympy import unevaluated_expression\n", + "from ampform.sympy import unevaluated\n", "\n", "\n", - "@unevaluated_expression(real=False)\n", + "@unevaluated(real=False)\n", "class PhspFactorSWave(sp.Expr):\n", " s: sp.Symbol\n", " m1: sp.Symbol\n", @@ -119,7 +119,7 @@ " return 16 * sp.pi * sp.I * cm\n", "\n", "\n", - "@unevaluated_expression(real=False)\n", + "@unevaluated(real=False)\n", "class BreakupMomentum(sp.Expr):\n", " s: sp.Symbol\n", " m1: sp.Symbol\n", @@ -166,7 +166,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Class variables and default arguments to instance arguments are also supported:" + "Class variables and default arguments to instance arguments are also supported. They can either be indicated with {class}`typing.ClassVar` or by not providing a type hint:" ] }, { @@ -180,11 +180,12 @@ "from typing import Any, ClassVar\n", "\n", "\n", - "@unevaluated_expression\n", + "@unevaluated\n", "class FunkyPower(sp.Expr):\n", " x: Any\n", " m: int = 1\n", " default_return: ClassVar[sp.Expr | None] = None\n", + " class_name = \"my name\"\n", " _latex_repr_ = R\"f_{{{m}}}\\left({x}\\right)\"\n", "\n", " def evaluate(self) -> sp.Expr | None:\n", diff --git a/src/ampform/dynamics/__init__.py b/src/ampform/dynamics/__init__.py index 84d7422c6..eec76f35f 100644 --- a/src/ampform/dynamics/__init__.py +++ b/src/ampform/dynamics/__init__.py @@ -26,14 +26,14 @@ UnevaluatedExpression, determine_indices, implement_doit_method, - unevaluated_expression, + unevaluated, ) if TYPE_CHECKING: from sympy.printing.latex import LatexPrinter -@unevaluated_expression +@unevaluated class BlattWeisskopfSquared(sp.Expr): # cspell:ignore pychyGekoppeltePartialwellenanalyseAnnihilationen r"""Blatt-Weisskopf function :math:`B_L^2(z)`, up to :math:`L \leq 8`. diff --git a/src/ampform/kinematics/phasespace.py b/src/ampform/kinematics/phasespace.py index 8796a8812..4b98e70c7 100644 --- a/src/ampform/kinematics/phasespace.py +++ b/src/ampform/kinematics/phasespace.py @@ -9,10 +9,10 @@ import sympy as sp -from ampform.sympy import unevaluated_expression +from ampform.sympy import unevaluated -@unevaluated_expression +@unevaluated class Kibble(sp.Expr): """Kibble function for determining the phase space region.""" @@ -34,7 +34,7 @@ def evaluate(self) -> Kallen: ) -@unevaluated_expression +@unevaluated class Kallen(sp.Expr): """Källén function, used for computing break-up momenta.""" diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index 6945e31b7..581448068 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -1,6 +1,6 @@ """Tools that facilitate in building :mod:`sympy` expressions. -.. autodecorator:: unevaluated_expression +.. autodecorator:: unevaluated .. dropdown:: SymPy assumptions .. autodata:: ExprClass @@ -30,7 +30,7 @@ from ._decorator import ( ExprClass, # noqa: F401 # pyright: ignore[reportUnusedImport] SymPyAssumptions, # noqa: F401 # pyright: ignore[reportUnusedImport] - unevaluated_expression, # noqa: F401 # pyright: ignore[reportUnusedImport] + unevaluated, # noqa: F401 # pyright: ignore[reportUnusedImport] ) if TYPE_CHECKING: diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index a3e42ce2c..138ec3698 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -3,9 +3,14 @@ import functools import inspect import sys -from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, overload +from collections import abc +from inspect import isclass +from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterable, TypeVar, overload import sympy as sp +from attrs import frozen +from sympy.core.basic import _aresame +from sympy.utilities.exceptions import SymPyDeprecationWarning if sys.version_info < (3, 8): from typing_extensions import Protocol, TypedDict @@ -13,16 +18,23 @@ from typing import Protocol, TypedDict if sys.version_info < (3, 11): - from typing_extensions import ParamSpec, Unpack, dataclass_transform + from typing_extensions import dataclass_transform else: - from typing import ParamSpec, Unpack, dataclass_transform + from typing import dataclass_transform if TYPE_CHECKING: from sympy.printing.latex import LatexPrinter + if sys.version_info < (3, 11): + from typing_extensions import ParamSpec, Unpack + else: + from typing import ParamSpec, Unpack + + H = TypeVar("H", bound=Hashable) + P = ParamSpec("P") + T = TypeVar("T") + ExprClass = TypeVar("ExprClass", bound=sp.Expr) -_P = ParamSpec("_P") -_T = TypeVar("_T") class SymPyAssumptions(TypedDict, total=False): @@ -56,17 +68,15 @@ class SymPyAssumptions(TypedDict, total=False): @overload -def unevaluated_expression(cls: type[ExprClass]) -> type[ExprClass]: ... +def unevaluated(cls: type[ExprClass]) -> type[ExprClass]: ... @overload -def unevaluated_expression( +def unevaluated( *, implement_doit: bool = True, **assumptions: Unpack[SymPyAssumptions], ) -> Callable[[type[ExprClass]], type[ExprClass]]: ... - - -@dataclass_transform() # type: ignore[misc] -def unevaluated_expression( # type: ignore[misc] +@dataclass_transform() +def unevaluated( cls: type[ExprClass] | None = None, *, implement_doit=True, **assumptions ): r"""Decorator for defining 'unevaluated' SymPy expressions. @@ -74,7 +84,7 @@ def unevaluated_expression( # type: ignore[misc] Unevaluated expressions are handy for defining large expressions that consist of several sub-definitions. - >>> @unevaluated_expression + >>> @unevaluated ... class MyExpr(sp.Expr): ... x: sp.Symbol ... y: sp.Symbol @@ -133,22 +143,54 @@ def _implement_new_method(cls: type[ExprClass]) -> type[ExprClass]: @functools.wraps(cls.__new__) @_insert_args_in_signature(attr_names, idx=1) def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: - positional_args, hints = _get_attribute_values(cls, attr_names, *args, **kwargs) - sympified_args = sp.sympify(positional_args) - expr = sp.Expr.__new__(cls, *sympified_args, **hints) - for name, value in zip(attr_names, sympified_args): + attr_values, hints = _get_attribute_values(cls, attr_names, *args, **kwargs) + converted_attr_values = _safe_sympify(*attr_values) + expr = sp.Expr.__new__(cls, *converted_attr_values.sympy, **hints) + for name, value in zip(attr_names, converted_attr_values.all_args): setattr(expr, name, value) + expr._all_args = converted_attr_values.all_args + expr._non_sympy_args = converted_attr_values.non_sympy if evaluate: return expr.evaluate() return expr cls.__new__ = new_method # type: ignore[method-assign] + cls._eval_subs = _eval_subs_method # type: ignore[method-assign] + cls._hashable_content = _hashable_content_method # type: ignore[method-assign] + cls._xreplace = _xreplace_method # type: ignore[method-assign] return cls +@overload +def _get_hashable_object(obj: type) -> str: ... # type: ignore[overload-overlap] +@overload +def _get_hashable_object(obj: H) -> H: ... +@overload +def _get_hashable_object(obj: Any) -> str: ... +def _get_hashable_object(obj): + if isclass(obj): + return str(obj) + try: + hash(obj) + except TypeError: + return str(obj) + return obj + + def _get_attribute_values( cls: type[ExprClass], attr_names: tuple[str, ...], *args, **kwargs ) -> tuple[tuple, dict[str, Any]]: + """Extract the attribute values from the constructor arguments. + + Returns a `tuple` of: + + 1. the extracted, ordered attributes as requested by :code:`attr_names`, + 2. a `dict` of remaining keyword arguments that can be used hints for the + constructed :class:`sp.Expr` instance. + + An attempt is made to get any missing attributes from the type hints in the class + definition. + """ if len(args) == len(attr_names): return args, kwargs if len(args) > len(attr_names): @@ -173,12 +215,46 @@ def _get_attribute_values( return tuple(attr_values), kwargs +def _safe_sympify(*args: Any) -> _ExprNewArumgents: + all_args = [] + sympy_args = [] + non_sympy_args = [] + for arg in args: + converted_arg, is_sympy = _try_sympify(arg) + if is_sympy: + sympy_args.append(converted_arg) + else: + non_sympy_args.append(converted_arg) + all_args.append(converted_arg) + return _ExprNewArumgents( + all_args=tuple(all_args), + sympy=tuple(sympy_args), + non_sympy=tuple(non_sympy_args), + ) + + +def _try_sympify(obj) -> tuple[Any, bool]: + if isinstance(obj, str): + return obj, False + try: + return sp.sympify(obj), True + except (TypeError, SymPyDeprecationWarning, sp.SympifyError): + return obj, False + + +@frozen +class _ExprNewArumgents: + all_args: tuple[Any, ...] + sympy: tuple[sp.Basic, ...] + non_sympy: tuple[Any, ...] + + class LatexMethod(Protocol): def __call__(self, printer: LatexPrinter, *args) -> str: ... @dataclass_transform() -def _implement_latex_repr(cls: type[_T]) -> type[_T]: +def _implement_latex_repr(cls: type[T]) -> type[T]: _latex_repr_: LatexMethod | str | None = getattr(cls, "_latex_repr_", None) if _latex_repr_ is None: msg = ( @@ -228,7 +304,7 @@ def _check_has_implementation(cls: type) -> None: def _insert_args_in_signature( new_params: Iterable[str] | None = None, idx: int = 0 -) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: +) -> Callable[[Callable[P, T]], Callable[P, T]]: if new_params is None: new_params = [] @@ -279,10 +355,73 @@ def _get_attribute_names(cls: type) -> tuple[str, ...]: @dataclass_transform() def _set_assumptions( **assumptions: Unpack[SymPyAssumptions], -) -> Callable[[type[_T]], type[_T]]: - def class_wrapper(cls: _T) -> _T: +) -> Callable[[type[T]], type[T]]: + def class_wrapper(cls: T) -> T: for assumption, value in assumptions.items(): setattr(cls, f"is_{assumption}", value) return cls return class_wrapper + + +def _eval_subs_method(self, old, new, **hints): + # https://github.com/sympy/sympy/blob/1.12/sympy/core/basic.py#L1117-L1147 + hit = False + substituted_attrs = list(self._all_args) + for i, old_attr in enumerate(substituted_attrs): + if not hasattr(old_attr, "_eval_subs"): + continue + if isclass(old_attr): + continue + new_attr = old_attr._subs(old, new, **hints) + if not _aresame(new_attr, old_attr): + hit = True + substituted_attrs[i] = new_attr + if hit: + rv = self.func(*substituted_attrs) + hack2 = hints.get("hack2", False) + if hack2 and self.is_Mul and not rv.is_Mul: # 2-arg hack + coefficient = sp.S.One + nonnumber = [] + for i in substituted_attrs: + if i.is_Number: + coefficient *= i + else: + nonnumber.append(i) + nonnumber = self.func(*nonnumber) + if coefficient is sp.S.One: + return nonnumber + return self.func(coefficient, nonnumber, evaluate=False) + return rv + return self + + +def _hashable_content_method(self) -> tuple: + hashable_content = super(sp.Expr, self)._hashable_content() + if not self._non_sympy_args: + return hashable_content + remaining_content = (_get_hashable_object(arg) for arg in self._non_sympy_args) + return (*hashable_content, *remaining_content) + + +def _xreplace_method(self, rule) -> tuple[sp.Expr, bool]: + # https://github.com/sympy/sympy/blob/1.12/sympy/core/basic.py#L1233-L1253 + if self in rule: + return rule[self], True + if rule: + new_args = [] + hit = False + for arg in self._all_args: + if hasattr(arg, "_xreplace") and not isclass(arg): + replace_result, is_replaced = arg._xreplace(rule) + elif isinstance(rule, abc.Mapping): + is_replaced = bool(arg in rule) + replace_result = rule.get(arg, arg) + else: + replace_result = arg + is_replaced = False + new_args.append(replace_result) + hit |= is_replaced + if hit: + return self.func(*new_args), True + return self, False diff --git a/tests/sympy/decorator/test_unevaluated_expression.py b/tests/sympy/decorator/test_unevaluated.py similarity index 58% rename from tests/sympy/decorator/test_unevaluated_expression.py rename to tests/sympy/decorator/test_unevaluated.py index f04de40bf..ce9bfdba1 100644 --- a/tests/sympy/decorator/test_unevaluated_expression.py +++ b/tests/sympy/decorator/test_unevaluated.py @@ -5,14 +5,15 @@ import sympy as sp -from ampform.sympy._decorator import unevaluated_expression +from ampform.sympy._decorator import unevaluated def test_classvar_behavior(): - @unevaluated_expression + @unevaluated class MyExpr(sp.Expr): x: float m: ClassVar[int] = 2 + class_name = "MyExpr" def evaluate(self) -> sp.Expr: return self.x**self.m # type: ignore[return-value] @@ -24,13 +25,37 @@ def evaluate(self) -> sp.Expr: y_expr = MyExpr(5) assert x_expr.doit() == 4**2 assert y_expr.doit() == 5**2 + assert x_expr.class_name == "MyExpr" + assert y_expr.class_name == "MyExpr" MyExpr.m = 3 + new_name = "different name" + MyExpr.class_name = new_name assert x_expr.doit() == 4**3 assert y_expr.doit() == 5**3 + assert x_expr.class_name == new_name + assert y_expr.class_name == new_name + + +def test_construction_non_sympy_attributes(): + class CannotBeSympified: ... + + @unevaluated(implement_doit=False) + class MyExpr(sp.Expr): + sympifiable: Any + non_sympy: CannotBeSympified + + obj = CannotBeSympified() + expr = MyExpr( + sympifiable=3, + non_sympy=obj, + ) + assert expr.sympifiable is not 3 # noqa: F632 + assert expr.sympifiable is sp.Integer(3) + assert expr.non_sympy is obj def test_default_argument(): - @unevaluated_expression + @unevaluated class MyExpr(sp.Expr): x: Any m: int = 2 @@ -47,7 +72,7 @@ def evaluate(self) -> sp.Expr: def test_default_argument_with_classvar(): - @unevaluated_expression + @unevaluated class FunkyPower(sp.Expr): x: Any m: int = 1 @@ -83,8 +108,21 @@ def evaluate(self) -> sp.Expr: assert expr.default_return is half +def test_hashable_with_classes(): + class CannotBeSympified: ... + + @unevaluated(implement_doit=False) + class MyExpr(sp.Expr): + x: Any + typ: type[CannotBeSympified] + + x = sp.Symbol("x") + expr = MyExpr(x, typ=CannotBeSympified) + assert expr._hashable_content() == (x, str(CannotBeSympified)) + + def test_no_implement_doit(): - @unevaluated_expression(implement_doit=False) + @unevaluated(implement_doit=False) class Squared(sp.Expr): x: Any @@ -95,7 +133,7 @@ def evaluate(self) -> sp.Expr: assert str(sqrt) == "Squared(2)" assert str(sqrt.doit()) == "Squared(2)" - @unevaluated_expression(complex=True, implement_doit=False) + @unevaluated(complex=True, implement_doit=False) class MySqrt(sp.Expr): x: Any @@ -104,8 +142,8 @@ class MySqrt(sp.Expr): assert expr.is_complex # type: ignore[attr-defined] -def test_symbols_and_no_symbols(): - @unevaluated_expression +def test_non_symbols_construction(): + @unevaluated class BreakupMomentum(sp.Expr): s: Any m1: Any @@ -131,3 +169,40 @@ def evaluate(self) -> sp.Expr: assert isinstance(q_value.s, sp.Integer) assert isinstance(q_value.m1, sp.Float) assert isinstance(q_value.m2, sp.Float) + + +def test_subs_with_non_sympy_attributes(): + class Protocol: ... + + @unevaluated(implement_doit=False) + class MyExpr(sp.Expr): + x: Any + protocol: type[Protocol] = Protocol + + x, y = sp.symbols("x y") + expr = MyExpr(x) + replaced_expr: MyExpr = expr.subs(x, y) + assert replaced_expr.x is not x + assert replaced_expr.x is y + assert replaced_expr.protocol is Protocol + + +def test_xreplace_with_non_sympy_attributes(): + class Protocol: ... + + class Protocol1(Protocol): ... + + class Protocol2(Protocol): ... + + @unevaluated(implement_doit=False) + class MyExpr(sp.Expr): + x: Any + protocol: type[Protocol] = Protocol1 + + x, y = sp.symbols("x y") + expr = MyExpr(x) + replaced_expr: MyExpr = expr.xreplace({x: y, Protocol1: Protocol2}) + assert replaced_expr.x is not x + assert replaced_expr.x is y + assert replaced_expr.protocol is not Protocol1 + assert replaced_expr.protocol is Protocol2