diff --git a/docs/conf.py b/docs/conf.py index 82270a326..9f3e3a73c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -49,17 +49,16 @@ add_module_names = False api_github_repo = f"{ORGANIZATION}/{REPO_NAME}" api_target_substitutions: dict[str, str | tuple[str, str]] = { - "T": "TypeVar", "BuilderReturnType": ("obj", "ampform.dynamics.builder.BuilderReturnType"), - "DecoratedClass": ("obj", "ampform.sympy.DecoratedClass"), - "DecoratedExpr": ("obj", "ampform.sympy.DecoratedExpr"), - "ExprClass": "ampform.sympy.ExprClass", + "DecoratedClass": ("obj", "ampform.sympy.deprecated.DecoratedClass"), + "DecoratedExpr": ("obj", "ampform.sympy.deprecated.DecoratedExpr"), "FourMomenta": ("obj", "ampform.kinematics.lorentz.FourMomenta"), "FourMomentumSymbol": ("obj", "ampform.kinematics.lorentz.FourMomentumSymbol"), "InteractionProperties": "qrules.quantum_numbers.InteractionProperties", "LatexPrinter": "sympy.printing.printer.Printer", "Literal[(-1, 1)]": "typing.Literal", "Literal[-1, 1]": "typing.Literal", + "NumPyPrintable": ("class", "ampform.sympy.NumPyPrintable"), "NumPyPrinter": "sympy.printing.printer.Printer", "ParameterValue": ("obj", "ampform.helicity.ParameterValue"), "Particle": "qrules.particle.Particle", @@ -67,12 +66,11 @@ "Slider": ("obj", "symplot.Slider"), "State": "qrules.transition.State", "StateTransition": "qrules.transition.StateTransition", + "T": "TypeVar", "Topology": "qrules.topology.Topology", "WignerD": "sympy.physics.quantum.spin.WignerD", "ampform.helicity._T": "typing.TypeVar", "ampform.sympy._decorator.ExprClass": ("obj", "ampform.sympy.ExprClass"), - "ampform.sympy._decorator.SymPyAssumptions": "ampform.sympy.SymPyAssumptions", - "an object providing a view on D's values": "typing.ValuesView", "sp.Basic": "sympy.core.basic.Basic", "sp.Expr": "sympy.core.expr.Expr", "sp.Float": "sympy.core.numbers.Float", @@ -289,7 +287,9 @@ nb_output_stderr = "remove" nitpick_ignore = [ ("py:class", "ArraySum"), + ("py:class", "ExprClass"), ("py:class", "MatrixMultiplication"), + ("py:class", "ampform.sympy._decorator.SymPyAssumptions"), ] nitpicky = True primary_domain = "py" diff --git a/pyproject.toml b/pyproject.toml index 9923f8665..73c33e22a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -249,6 +249,7 @@ filterwarnings = [ "error", "ignore:.*invalid value encountered in sqrt.*:RuntimeWarning", "ignore:.*is deprecated and slated for removal in Python 3.14:DeprecationWarning", + "ignore:.*the @ampform.sympy.unevaluated_expression decorator instead( with commutative=True)?:DeprecationWarning", "ignore:Passing a schema to Validator.iter_errors is deprecated.*:DeprecationWarning", "ignore:The .* argument to NotebookFile is deprecated.*:pytest.PytestRemovedIn8Warning", "ignore:The distutils package is deprecated.*:DeprecationWarning", diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index 958d189df..e0c9db2de 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -10,7 +10,6 @@ """ -# cspell:ignore mhash from __future__ import annotations import functools @@ -23,7 +22,7 @@ from abc import abstractmethod from os.path import abspath, dirname, expanduser from textwrap import dedent -from typing import TYPE_CHECKING, Callable, Iterable, Sequence, SupportsFloat, TypeVar +from typing import TYPE_CHECKING, Iterable, Sequence, SupportsFloat import sympy as sp from sympy.printing.conventions import split_super_sub @@ -35,6 +34,13 @@ argument, # noqa: F401 # pyright: ignore[reportUnusedImport] unevaluated, # noqa: F401 # pyright: ignore[reportUnusedImport] ) +from .deprecated import ( + UnevaluatedExpression, # noqa: F401 # pyright: ignore[reportUnusedImport] + create_expression, # noqa: F401 # pyright: ignore[reportUnusedImport] + implement_doit_method, # noqa: F401 # pyright: ignore[reportUnusedImport] + implement_expr, # pyright: ignore[reportUnusedImport] # noqa: F401 + make_commutative, # pyright: ignore[reportUnusedImport] # noqa: F401 +) if TYPE_CHECKING: from sympy.printing.latex import LatexPrinter @@ -43,133 +49,13 @@ _LOGGER = logging.getLogger(__name__) -class UnevaluatedExpression(sp.Expr): - """Base class for expression classes with an :meth:`evaluate` method. - - Deriving from `~sympy.core.expr.Expr` allows us to keep expression trees condense - before unfolding them with their `~sympy.core.basic.Basic.doit` method. This allows - us to: - - 1. condense the LaTeX representation of an expression tree by providing a custom - :meth:`_latex` method. - 2. overwrite its printer methods (see `NumPyPrintable` and e.g. - :doc:`compwa-org:report/001`). - - The `UnevaluatedExpression` base class makes implementations of its derived classes - more secure by enforcing the developer to provide implementations for these methods, - so that SymPy mechanisms work correctly. Decorators like :func:`implement_expr` and - :func:`implement_doit_method` provide convenient means to implement the missing - methods. - - .. autolink-preface:: - - import sympy as sp - from ampform.sympy import UnevaluatedExpression, create_expression - - .. automethod:: __new__ - .. automethod:: evaluate - .. automethod:: _latex - """ - - # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L74-L77 - __slots__: tuple[str] = ("_name",) - _name: str | None - """Optional instance attribute that can be used in LaTeX representations.""" - - def __new__( - cls: type[DecoratedClass], - *args, - name: str | None = None, - **hints, - ) -> DecoratedClass: - """Constructor for a class derived from `UnevaluatedExpression`. - - This :meth:`~object.__new__` method correctly sets the - `~sympy.core.basic.Basic.args`, assumptions etc. Overwrite it in order to - further specify its signature. The function :func:`create_expression` can be - used in its implementation, like so: - - >>> class MyExpression(UnevaluatedExpression): - ... def __new__( - ... cls, x: sp.Symbol, y: sp.Symbol, n: int, **hints - ... ) -> "MyExpression": - ... return create_expression(cls, x, y, n, **hints) - ... - ... def evaluate(self) -> sp.Expr: - ... x, y, n = self.args - ... return (x + y)**n - ... - >>> x, y = sp.symbols("x y") - >>> expr = MyExpression(x, y, n=3) - >>> expr - MyExpression(x, y, 3) - >>> expr.evaluate() - (x + y)**3 - """ - # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L113-L119 - obj = object.__new__(cls) - obj._args = args - obj._assumptions = cls.default_assumptions # type: ignore[attr-defined] - obj._mhash = None - obj._name = name - return obj - - def __getnewargs_ex__(self) -> tuple[tuple, dict]: - # Pickling support, see - # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L124-L126 - args = tuple(self.args) - kwargs = {"name": self._name} - return args, kwargs - - def _hashable_content(self) -> tuple: - # https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165 - # name is converted to string because unstable hash for None - return (*super()._hashable_content(), str(self._name)) - - @abstractmethod - def evaluate(self) -> sp.Expr: - """Evaluate and 'unfold' this `UnevaluatedExpression` by one level. - - >>> from ampform.dynamics import BreakupMomentumSquared - >>> s, m1, m2 = sp.symbols("s m1 m2") - >>> expr = BreakupMomentumSquared(s, m1, m2) - >>> expr - BreakupMomentumSquared(s, m1, m2) - >>> expr.evaluate() - (s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s) - >>> expr.doit(deep=False) - (s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s) - - .. note:: When decorating this class with :func:`implement_doit_method`, - its :meth:`evaluate` method is equivalent to - :meth:`~sympy.core.basic.Basic.doit` with :code:`deep=False`. - """ - - def _latex(self, printer: LatexPrinter, *args) -> str: - r"""Provide a mathematical Latex representation for pretty printing. - - >>> from ampform.dynamics import BreakupMomentumSquared - >>> s, m1 = sp.symbols("s m1") - >>> expr = BreakupMomentumSquared(s, m1, m1) - >>> print(sp.latex(expr)) - q^2\left(s\right) - >>> print(sp.latex(expr.doit())) - - m_{1}^{2} + \frac{s}{4} - """ - args = tuple(map(printer._print, self.args)) - name = type(self).__name__ - if self._name is not None: - name = self._name - return f"{name}{args}" - - class NumPyPrintable(sp.Expr): r"""`~sympy.core.expr.Expr` class that can lambdify to NumPy code. - This interface for classes that derive from `sympy.Expr ` - enforce the implementation of a :meth:`_numpycode` method in case the class does not - correctly :func:`~sympy.utilities.lambdify.lambdify` to NumPy code. For more info on - SymPy printers, see :doc:`sympy:modules/printing`. + This interface is for classes that derive from `sympy.Expr ` + and that require a :meth:`_numpycode` method in case the class does not correctly + :func:`~sympy.utilities.lambdify.lambdify` to NumPy code. For more info on SymPy + printers, see :doc:`sympy:modules/printing`. Several computational frameworks try to converge their interface to that of NumPy. See for instance `TensorFlow's NumPy API @@ -179,9 +65,9 @@ class NumPyPrintable(sp.Expr): :func:`~sympy.utilities.lambdify.lambdify` SymPy expressions to these different backends with the same lambdification code. - .. note:: This interface differs from `UnevaluatedExpression` in that it **should - not** implement an :meth:`.evaluate` (and therefore a - :meth:`~sympy.core.basic.Basic.doit`) method. + .. warning:: If you decorate this class with :func:`unevaluated`, you usually want + to do so with :code:`implement_doit=False`, because you do not want the class + to be 'unfolded' with :meth:`~sympy.core.basic.Basic.doit` before lambdification. .. warning:: The implemented :meth:`_numpycode` method should countain as little @@ -201,117 +87,6 @@ def _numpycode(self, printer: NumPyPrinter, *args) -> str: """Lambdify this `NumPyPrintable` class to NumPy code.""" -DecoratedClass = TypeVar("DecoratedClass", bound=UnevaluatedExpression) -"""`~typing.TypeVar` for decorators like :func:`implement_doit_method`.""" - - -def implement_expr( - n_args: int, -) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]: - """Decorator for classes that derive from `UnevaluatedExpression`. - - Implement a :meth:`~object.__new__` and :meth:`~sympy.core.basic.Basic.doit` method - for a class that derives from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). - """ - - def decorator( - decorated_class: type[DecoratedClass], - ) -> type[DecoratedClass]: - decorated_class = implement_new_method(n_args)(decorated_class) - return implement_doit_method(decorated_class) - - return decorator - - -def implement_new_method( - n_args: int, -) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]: - """Implement :meth:`UnevaluatedExpression.__new__` on a derived class. - - Implement a :meth:`~object.__new__` method for a class that derives from - `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). - """ - - def decorator( - decorated_class: type[DecoratedClass], - ) -> type[DecoratedClass]: - def new_method( - cls: type[DecoratedClass], - *args: sp.Symbol, - evaluate: bool = False, - **hints, - ) -> DecoratedClass: - if len(args) != n_args: - msg = f"{n_args} parameters expected, got {len(args)}" - raise ValueError(msg) - args = sp.sympify(args) - expr = UnevaluatedExpression.__new__(cls, *args) - if evaluate: - return expr.evaluate() # type: ignore[return-value] - return expr - - decorated_class.__new__ = new_method # type: ignore[assignment] - return decorated_class - - return decorator - - -def implement_doit_method( - decorated_class: type[DecoratedClass], -) -> type[DecoratedClass]: - """Implement ``doit()`` method for an `UnevaluatedExpression` class. - - Implement a :meth:`~sympy.core.basic.Basic.doit` method for a class that derives - from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). A - :meth:`~sympy.core.basic.Basic.doit` method is an extension of an - :meth:`~.UnevaluatedExpression.evaluate` method in the sense that it can work - recursively on deeper expression trees. - """ - - @functools.wraps(decorated_class.doit) # type: ignore[attr-defined] - def doit_method(self: UnevaluatedExpression, deep: bool = True) -> sp.Expr: - expr = self.evaluate() - if deep: - return expr.doit() - return expr - - decorated_class.doit = doit_method # type: ignore[assignment] - return decorated_class - - -DecoratedExpr = TypeVar("DecoratedExpr", bound=sp.Expr) -"""`~typing.TypeVar` for decorators like :func:`make_commutative`.""" - - -def make_commutative( - decorated_class: type[DecoratedExpr], -) -> type[DecoratedExpr]: - """Set commutative and 'extended real' assumptions on expression class. - - .. seealso:: :doc:`sympy:guides/assumptions` - """ - decorated_class.is_commutative = True # type: ignore[attr-defined] - decorated_class.is_extended_real = True # type: ignore[attr-defined] - return decorated_class - - -def create_expression( - cls: type[DecoratedExpr], - *args, - evaluate: bool = False, - name: str | None = None, - **kwargs, -) -> DecoratedExpr: - """Helper function for implementing `UnevaluatedExpression.__new__`.""" - args = sp.sympify(args) - if issubclass(cls, UnevaluatedExpression): - expr = UnevaluatedExpression.__new__(cls, *args, name=name, **kwargs) - if evaluate: - return expr.evaluate() # type: ignore[return-value] - return expr # type: ignore[return-value] - return sp.Expr.__new__(cls, *args, **kwargs) # type: ignore[return-value] - - def create_symbol_matrix(name: str, m: int, n: int) -> sp.MutableDenseMatrix: """Create a `~sympy.matrices.dense.Matrix` with symbols as elements. @@ -332,8 +107,7 @@ def create_symbol_matrix(name: str, m: int, n: int) -> sp.MutableDenseMatrix: return sp.Matrix([[symbol[i, j] for j in range(n)] for i in range(m)]) -@implement_doit_method -class PoolSum(UnevaluatedExpression): +class PoolSum(sp.Expr): r"""Sum over indices where the values are taken from a domain set. >>> i, j, m, n = sp.symbols("i j m n") @@ -352,6 +126,7 @@ def __new__( cls, expression, *indices: tuple[sp.Symbol, Iterable[sp.Basic]], + evaluate: bool = False, **hints, ) -> PoolSum: converted_indices = [] @@ -361,7 +136,11 @@ def __new__( msg = f"No values provided for index {idx_symbol}" raise ValueError(msg) converted_indices.append((idx_symbol, values)) - return create_expression(cls, expression, *converted_indices, **hints) + args = sp.sympify((expression, *converted_indices)) + expr: PoolSum = sp.Expr.__new__(cls, *args, **hints) + if evaluate: + return expr.evaluate() # type: ignore[return-value] + return expr @property def expression(self) -> sp.Expr: @@ -375,6 +154,12 @@ def indices(self) -> list[tuple[sp.Symbol, tuple[sp.Float, ...]]]: def free_symbols(self) -> set[sp.Basic]: return super().free_symbols - {s for s, _ in self.indices} + def doit(self, deep: bool = True) -> sp.Expr: # type: ignore[override] + expr = self.evaluate() + if deep: + return expr.doit() + return expr + def evaluate(self) -> sp.Expr: indices = {symbol: tuple(values) for symbol, values in self.indices} return sp.Add(*[ diff --git a/src/ampform/sympy/_array_expressions.py b/src/ampform/sympy/_array_expressions.py index c16835c63..22128596c 100644 --- a/src/ampform/sympy/_array_expressions.py +++ b/src/ampform/sympy/_array_expressions.py @@ -26,8 +26,6 @@ get_shape, ) -from ampform.sympy import create_expression, make_commutative - if TYPE_CHECKING: from sympy.printing.numpy import NumPyPrinter @@ -254,7 +252,8 @@ class ArraySum(sp.Expr): precedence = PRECEDENCE["Add"] def __new__(cls, *terms: sp.Basic, **hints) -> ArraySum: - return create_expression(cls, *terms, **hints) + terms = sp.sympify(terms) + return sp.Expr.__new__(cls, *terms, **hints) @property def terms(self) -> tuple[sp.Basic, ...]: @@ -305,13 +304,15 @@ def _strip_subscript_superscript(symbol: sp.Basic) -> str: return name -@make_commutative class ArrayAxisSum(sp.Expr): + is_commutative = True + def __new__(cls, array: sp.Expr, axis: int | None = None, **hints) -> ArrayAxisSum: if axis is not None and not isinstance(axis, (int, sp.Integer)): msg = "Only single digits allowed for axis" raise TypeError(msg) - return create_expression(cls, array, axis, **hints) + args = sp.sympify((array, axis)) + return sp.Expr.__new__(cls, *args, **hints) @property def array(self) -> sp.Expr: @@ -346,7 +347,8 @@ class ArrayMultiplication(sp.Expr): """ def __new__(cls, *tensors: sp.Basic, **hints) -> ArrayMultiplication: - return create_expression(cls, *tensors, **hints) + tensors = sp.sympify(tensors) + return sp.Expr.__new__(cls, *tensors, **hints) @property def tensors(self) -> list[sp.Expr]: @@ -399,7 +401,8 @@ class MatrixMultiplication(sp.Expr): """ def __new__(cls, *tensors: sp.Basic, **hints) -> MatrixMultiplication: - return create_expression(cls, *tensors, **hints) + tensors = sp.sympify(tensors) + return sp.Expr.__new__(cls, *tensors, **hints) @property def tensors(self) -> tuple[sp.Basic, ...]: diff --git a/src/ampform/sympy/deprecated.py b/src/ampform/sympy/deprecated.py new file mode 100644 index 000000000..85e402dae --- /dev/null +++ b/src/ampform/sympy/deprecated.py @@ -0,0 +1,286 @@ +"""Deprecated classes and functions for constructing unevaluated expressions. + +.. deprecated:: 0.15.0 +""" + +from __future__ import annotations + +import functools +from abc import abstractmethod +from typing import TYPE_CHECKING, Callable, TypeVar +from warnings import warn + +import sympy as sp + +if TYPE_CHECKING: + from sympy.printing.latex import LatexPrinter + + +class UnevaluatedExpression(sp.Expr): + """Base class for expression classes with an :meth:`evaluate` method. + + Deriving from `~sympy.core.expr.Expr` allows us to keep expression trees condense + before unfolding them with their `~sympy.core.basic.Basic.doit` method. This allows + us to: + + 1. condense the LaTeX representation of an expression tree by providing a custom + :meth:`_latex` method. + 2. overwrite its printer methods (see `.NumPyPrintable` and e.g. + :doc:`compwa-org:report/001`). + + The `UnevaluatedExpression` base class makes implementations of its derived classes + more secure by enforcing the developer to provide implementations for these methods, + so that SymPy mechanisms work correctly. Decorators like :func:`implement_expr` and + :func:`implement_doit_method` provide convenient means to implement the missing + methods. + + .. autolink-preface:: + + import sympy as sp + from ampform.sympy import UnevaluatedExpression, create_expression + + .. automethod:: __new__ + .. automethod:: evaluate + .. automethod:: _latex + """ + + # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L74-L77 + __slots__: tuple[str] = ("_name",) + _name: str | None + """Optional instance attribute that can be used in LaTeX representations.""" + + def __init_subclass__(cls, **kwargs): + warn( + f"{cls.__name__} is deprecated, use the" + " @ampform.sympy.unevaluated_expression decorator instead", + category=DeprecationWarning, + stacklevel=1, + ) + super().__init_subclass__(**kwargs) + + def __new__( + cls: type[DecoratedClass], + *args, + name: str | None = None, + **hints, + ) -> DecoratedClass: + """Constructor for a class derived from `UnevaluatedExpression`. + + This :meth:`~object.__new__` method correctly sets the + `~sympy.core.basic.Basic.args`, assumptions etc. Overwrite it in order to + further specify its signature. The function :func:`create_expression` can be + used in its implementation, like so: + + >>> class MyExpression(UnevaluatedExpression): + ... def __new__( + ... cls, x: sp.Symbol, y: sp.Symbol, n: int, **hints + ... ) -> "MyExpression": + ... return create_expression(cls, x, y, n, **hints) + ... + ... def evaluate(self) -> sp.Expr: + ... x, y, n = self.args + ... return (x + y)**n + ... + >>> x, y = sp.symbols("x y") + >>> expr = MyExpression(x, y, n=3) + >>> expr + MyExpression(x, y, 3) + >>> expr.evaluate() + (x + y)**3 + """ + # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L113-L119 + obj = object.__new__(cls) + obj._args = args + obj._assumptions = cls.default_assumptions # type: ignore[attr-defined] + obj._mhash = None # cspell:ignore mhash + obj._name = name + return obj + + def __getnewargs_ex__(self) -> tuple[tuple, dict]: + # Pickling support, see + # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L124-L126 + args = tuple(self.args) + kwargs = {"name": self._name} + return args, kwargs + + def _hashable_content(self) -> tuple: + # https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165 + # name is converted to string because unstable hash for None + return (*super()._hashable_content(), str(self._name)) + + @abstractmethod + def evaluate(self) -> sp.Expr: + """Evaluate and 'unfold' this `UnevaluatedExpression` by one level. + + >>> from ampform.dynamics import BreakupMomentumSquared + >>> s, m1, m2 = sp.symbols("s m1 m2") + >>> expr = BreakupMomentumSquared(s, m1, m2) + >>> expr + BreakupMomentumSquared(s, m1, m2) + >>> expr.evaluate() + (s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s) + >>> expr.doit(deep=False) + (s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s) + + .. note:: When decorating this class with :func:`implement_doit_method`, + its :meth:`evaluate` method is equivalent to + :meth:`~sympy.core.basic.Basic.doit` with :code:`deep=False`. + """ + + def _latex(self, printer: LatexPrinter, *args) -> str: + r"""Provide a mathematical Latex representation for pretty printing. + + >>> from ampform.dynamics import BreakupMomentumSquared + >>> s, m1 = sp.symbols("s m1") + >>> expr = BreakupMomentumSquared(s, m1, m1) + >>> print(sp.latex(expr)) + q^2\left(s\right) + >>> print(sp.latex(expr.doit())) + - m_{1}^{2} + \frac{s}{4} + """ + args = tuple(map(printer._print, self.args)) + name = type(self).__name__ + if self._name is not None: + name = self._name + return f"{name}{args}" + + +DecoratedClass = TypeVar("DecoratedClass", bound=UnevaluatedExpression) +"""`~typing.TypeVar` for decorators like :func:`implement_doit_method`.""" + + +def implement_expr( + n_args: int, +) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]: + """Decorator for classes that derive from `UnevaluatedExpression`. + + Implement a :meth:`~object.__new__` and :meth:`~sympy.core.basic.Basic.doit` method + for a class that derives from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). + """ + warn( + "@implement_expr is deprecated, use the @ampform.sympy.unevaluated_expression" + " decorator instead", + category=DeprecationWarning, + stacklevel=1, + ) + + def decorator( + decorated_class: type[DecoratedClass], + ) -> type[DecoratedClass]: + decorated_class = implement_new_method(n_args)(decorated_class) + return implement_doit_method(decorated_class) + + return decorator + + +def implement_new_method( + n_args: int, +) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]: + """Implement :meth:`UnevaluatedExpression.__new__` on a derived class. + + Implement a :meth:`~object.__new__` method for a class that derives from + `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). + """ + warn( + "@implement_new_method is deprecated, use the" + " @ampform.sympy.unevaluated_expression decorator instead", + category=DeprecationWarning, + stacklevel=1, + ) + + def decorator( + decorated_class: type[DecoratedClass], + ) -> type[DecoratedClass]: + def new_method( + cls: type[DecoratedClass], + *args: sp.Symbol, + evaluate: bool = False, + **hints, + ) -> DecoratedClass: + if len(args) != n_args: + msg = f"{n_args} parameters expected, got {len(args)}" + raise ValueError(msg) + args = sp.sympify(args) + expr = UnevaluatedExpression.__new__(cls, *args) + if evaluate: + return expr.evaluate() # type: ignore[return-value] + return expr + + decorated_class.__new__ = new_method # type: ignore[assignment] + return decorated_class + + return decorator + + +def implement_doit_method( + decorated_class: type[DecoratedClass], +) -> type[DecoratedClass]: + """Implement ``doit()`` method for an `UnevaluatedExpression` class. + + Implement a :meth:`~sympy.core.basic.Basic.doit` method for a class that derives + from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). A + :meth:`~sympy.core.basic.Basic.doit` method is an extension of an + :meth:`~.UnevaluatedExpression.evaluate` method in the sense that it can work + recursively on deeper expression trees. + """ + warn( + "@implement_doit_method is deprecated, use the" + " @ampform.sympy.unevaluated_expression decorator instead", + category=DeprecationWarning, + stacklevel=1, + ) + + @functools.wraps(decorated_class.doit) # type: ignore[attr-defined] + def doit_method(self: UnevaluatedExpression, deep: bool = True) -> sp.Expr: + expr = self.evaluate() + if deep: + return expr.doit() + return expr + + decorated_class.doit = doit_method # type: ignore[assignment] + return decorated_class + + +DecoratedExpr = TypeVar("DecoratedExpr", bound=sp.Expr) +"""`~typing.TypeVar` for decorators like :func:`make_commutative`.""" + + +def make_commutative( + decorated_class: type[DecoratedExpr], +) -> type[DecoratedExpr]: + """Set commutative and 'extended real' assumptions on expression class. + + .. seealso:: :doc:`sympy:guides/assumptions` + """ + warn( + "@make_commutative is deprecated, use the @ampform.sympy.unevaluated_expression" + " decorator instead with commutative=True", + category=DeprecationWarning, + stacklevel=1, + ) + decorated_class.is_commutative = True # type: ignore[attr-defined] + decorated_class.is_extended_real = True # type: ignore[attr-defined] + return decorated_class + + +def create_expression( + cls: type[DecoratedExpr], + *args, + evaluate: bool = False, + name: str | None = None, + **kwargs, +) -> DecoratedExpr: + """Helper function for implementing `UnevaluatedExpression.__new__`.""" + warn( + "create_expression() is deprecated, construct the class with the" + " @ampform.sympy.unevaluated_expression decorator instead", + category=DeprecationWarning, + stacklevel=1, + ) + args = sp.sympify(args) + if issubclass(cls, UnevaluatedExpression): + expr = UnevaluatedExpression.__new__(cls, *args, name=name, **kwargs) + if evaluate: + return expr.evaluate() # type: ignore[return-value] + return expr # type: ignore[return-value] + return sp.Expr.__new__(cls, *args, **kwargs) # type: ignore[return-value] diff --git a/src/ampform/sympy/math.py b/src/ampform/sympy/math.py index 553177e14..d050472c2 100644 --- a/src/ampform/sympy/math.py +++ b/src/ampform/sympy/math.py @@ -8,7 +8,7 @@ import sympy as sp from sympy.plotting.experimental_lambdify import Lambdifier -from ampform.sympy import NumPyPrintable, create_expression, make_commutative +from ampform.sympy import NumPyPrintable if TYPE_CHECKING: from sympy.printing.numpy import NumPyPrinter @@ -16,7 +16,6 @@ from sympy.printing.pycode import PythonCodePrinter -@make_commutative class ComplexSqrt(NumPyPrintable): """Square root that returns positive imaginary values for negative input. @@ -27,13 +26,17 @@ class ComplexSqrt(NumPyPrintable): :func:`~sympy.utilities.lambdify.lambdify` printer. """ + is_commutative = True + is_extended_real = True + @overload def __new__(cls, x: sp.Number, *args, **kwargs) -> sp.Expr: ... # type: ignore[misc] @overload def __new__(cls, x: sp.Expr, *args, **kwargs) -> ComplexSqrt: ... def __new__(cls, x, *args, **kwargs): x = sp.sympify(x) - expr = create_expression(cls, x, *args, **kwargs) + args = sp.sympify((x, *args)) + expr: ComplexSqrt = sp.Expr.__new__(cls, *args, **kwargs) # type: ignore[annotation-unchecked] if isinstance(x, sp.Number): return expr.get_definition() return expr @@ -60,13 +63,7 @@ def __print_complex(self, printer: Printer) -> str: return printer._print(expr) def get_definition(self) -> sp.Piecewise: - """Get a symbolic definition for this expression class. - - .. note:: This class is `.NumPyPrintable`, so should not have an - :meth:`~.UnevaluatedExpression.evaluate` method (in order to block - :meth:`~sympy.core.basic.Basic.doit`). This method serves as an equivalent - to that. - """ + """Get a symbolic definition for this expression class.""" x: sp.Expr = self.args[0] # type: ignore[assignment] return sp.Piecewise( (sp.I * sp.sqrt(-x), x < 0), diff --git a/tests/dynamics/test_sympy.py b/tests/dynamics/test_deprecated.py similarity index 100% rename from tests/dynamics/test_sympy.py rename to tests/dynamics/test_deprecated.py