Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BREAK: deprecate UnevaluatedExpression templates #383

Merged
merged 3 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,30 +49,28 @@
add_module_names = False
api_github_repo = f"{ORGANIZATION}/{REPO_NAME}"
api_target_substitutions: dict[str, str | tuple[str, str]] = {
"T": "TypeVar",
"BuilderReturnType": ("obj", "ampform.dynamics.builder.BuilderReturnType"),
"DecoratedClass": ("obj", "ampform.sympy.DecoratedClass"),
"DecoratedExpr": ("obj", "ampform.sympy.DecoratedExpr"),
"ExprClass": "ampform.sympy.ExprClass",
"DecoratedClass": ("obj", "ampform.sympy.deprecated.DecoratedClass"),
"DecoratedExpr": ("obj", "ampform.sympy.deprecated.DecoratedExpr"),
"FourMomenta": ("obj", "ampform.kinematics.lorentz.FourMomenta"),
"FourMomentumSymbol": ("obj", "ampform.kinematics.lorentz.FourMomentumSymbol"),
"InteractionProperties": "qrules.quantum_numbers.InteractionProperties",
"LatexPrinter": "sympy.printing.printer.Printer",
"Literal[(-1, 1)]": "typing.Literal",
"Literal[-1, 1]": "typing.Literal",
"NumPyPrintable": ("class", "ampform.sympy.NumPyPrintable"),
"NumPyPrinter": "sympy.printing.printer.Printer",
"ParameterValue": ("obj", "ampform.helicity.ParameterValue"),
"Particle": "qrules.particle.Particle",
"ReactionInfo": "qrules.transition.ReactionInfo",
"Slider": ("obj", "symplot.Slider"),
"State": "qrules.transition.State",
"StateTransition": "qrules.transition.StateTransition",
"T": "TypeVar",
"Topology": "qrules.topology.Topology",
"WignerD": "sympy.physics.quantum.spin.WignerD",
"ampform.helicity._T": "typing.TypeVar",
"ampform.sympy._decorator.ExprClass": ("obj", "ampform.sympy.ExprClass"),
"ampform.sympy._decorator.SymPyAssumptions": "ampform.sympy.SymPyAssumptions",
"an object providing a view on D's values": "typing.ValuesView",
"sp.Basic": "sympy.core.basic.Basic",
"sp.Expr": "sympy.core.expr.Expr",
"sp.Float": "sympy.core.numbers.Float",
Expand Down Expand Up @@ -289,7 +287,9 @@
nb_output_stderr = "remove"
nitpick_ignore = [
("py:class", "ArraySum"),
("py:class", "ExprClass"),
("py:class", "MatrixMultiplication"),
("py:class", "ampform.sympy._decorator.SymPyAssumptions"),
]
nitpicky = True
primary_domain = "py"
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ filterwarnings = [
"error",
"ignore:.*invalid value encountered in sqrt.*:RuntimeWarning",
"ignore:.*is deprecated and slated for removal in Python 3.14:DeprecationWarning",
"ignore:.*the @ampform.sympy.unevaluated_expression decorator instead( with commutative=True)?:DeprecationWarning",
"ignore:Passing a schema to Validator.iter_errors is deprecated.*:DeprecationWarning",
"ignore:The .* argument to NotebookFile is deprecated.*:pytest.PytestRemovedIn8Warning",
"ignore:The distutils package is deprecated.*:DeprecationWarning",
Expand Down
271 changes: 28 additions & 243 deletions src/ampform/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

"""

# cspell:ignore mhash
from __future__ import annotations

import functools
Expand All @@ -23,7 +22,7 @@
from abc import abstractmethod
from os.path import abspath, dirname, expanduser
from textwrap import dedent
from typing import TYPE_CHECKING, Callable, Iterable, Sequence, SupportsFloat, TypeVar
from typing import TYPE_CHECKING, Iterable, Sequence, SupportsFloat

import sympy as sp
from sympy.printing.conventions import split_super_sub
Expand All @@ -35,6 +34,13 @@
argument, # noqa: F401 # pyright: ignore[reportUnusedImport]
unevaluated, # noqa: F401 # pyright: ignore[reportUnusedImport]
)
from .deprecated import (
UnevaluatedExpression, # noqa: F401 # pyright: ignore[reportUnusedImport]
create_expression, # noqa: F401 # pyright: ignore[reportUnusedImport]
implement_doit_method, # noqa: F401 # pyright: ignore[reportUnusedImport]
implement_expr, # pyright: ignore[reportUnusedImport] # noqa: F401
make_commutative, # pyright: ignore[reportUnusedImport] # noqa: F401
)

if TYPE_CHECKING:
from sympy.printing.latex import LatexPrinter
Expand All @@ -43,133 +49,13 @@
_LOGGER = logging.getLogger(__name__)


class UnevaluatedExpression(sp.Expr):
"""Base class for expression classes with an :meth:`evaluate` method.

Deriving from `~sympy.core.expr.Expr` allows us to keep expression trees condense
before unfolding them with their `~sympy.core.basic.Basic.doit` method. This allows
us to:

1. condense the LaTeX representation of an expression tree by providing a custom
:meth:`_latex` method.
2. overwrite its printer methods (see `NumPyPrintable` and e.g.
:doc:`compwa-org:report/001`).

The `UnevaluatedExpression` base class makes implementations of its derived classes
more secure by enforcing the developer to provide implementations for these methods,
so that SymPy mechanisms work correctly. Decorators like :func:`implement_expr` and
:func:`implement_doit_method` provide convenient means to implement the missing
methods.

.. autolink-preface::

import sympy as sp
from ampform.sympy import UnevaluatedExpression, create_expression

.. automethod:: __new__
.. automethod:: evaluate
.. automethod:: _latex
"""

# https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L74-L77
__slots__: tuple[str] = ("_name",)
_name: str | None
"""Optional instance attribute that can be used in LaTeX representations."""

def __new__(
cls: type[DecoratedClass],
*args,
name: str | None = None,
**hints,
) -> DecoratedClass:
"""Constructor for a class derived from `UnevaluatedExpression`.

This :meth:`~object.__new__` method correctly sets the
`~sympy.core.basic.Basic.args`, assumptions etc. Overwrite it in order to
further specify its signature. The function :func:`create_expression` can be
used in its implementation, like so:

>>> class MyExpression(UnevaluatedExpression):
... def __new__(
... cls, x: sp.Symbol, y: sp.Symbol, n: int, **hints
... ) -> "MyExpression":
... return create_expression(cls, x, y, n, **hints)
...
... def evaluate(self) -> sp.Expr:
... x, y, n = self.args
... return (x + y)**n
...
>>> x, y = sp.symbols("x y")
>>> expr = MyExpression(x, y, n=3)
>>> expr
MyExpression(x, y, 3)
>>> expr.evaluate()
(x + y)**3
"""
# https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L113-L119
obj = object.__new__(cls)
obj._args = args
obj._assumptions = cls.default_assumptions # type: ignore[attr-defined]
obj._mhash = None
obj._name = name
return obj

def __getnewargs_ex__(self) -> tuple[tuple, dict]:
# Pickling support, see
# https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L124-L126
args = tuple(self.args)
kwargs = {"name": self._name}
return args, kwargs

def _hashable_content(self) -> tuple:
# https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165
# name is converted to string because unstable hash for None
return (*super()._hashable_content(), str(self._name))

@abstractmethod
def evaluate(self) -> sp.Expr:
"""Evaluate and 'unfold' this `UnevaluatedExpression` by one level.

>>> from ampform.dynamics import BreakupMomentumSquared
>>> s, m1, m2 = sp.symbols("s m1 m2")
>>> expr = BreakupMomentumSquared(s, m1, m2)
>>> expr
BreakupMomentumSquared(s, m1, m2)
>>> expr.evaluate()
(s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s)
>>> expr.doit(deep=False)
(s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s)

.. note:: When decorating this class with :func:`implement_doit_method`,
its :meth:`evaluate` method is equivalent to
:meth:`~sympy.core.basic.Basic.doit` with :code:`deep=False`.
"""

def _latex(self, printer: LatexPrinter, *args) -> str:
r"""Provide a mathematical Latex representation for pretty printing.

>>> from ampform.dynamics import BreakupMomentumSquared
>>> s, m1 = sp.symbols("s m1")
>>> expr = BreakupMomentumSquared(s, m1, m1)
>>> print(sp.latex(expr))
q^2\left(s\right)
>>> print(sp.latex(expr.doit()))
- m_{1}^{2} + \frac{s}{4}
"""
args = tuple(map(printer._print, self.args))
name = type(self).__name__
if self._name is not None:
name = self._name
return f"{name}{args}"


class NumPyPrintable(sp.Expr):
r"""`~sympy.core.expr.Expr` class that can lambdify to NumPy code.

This interface for classes that derive from `sympy.Expr <sympy.core.expr.Expr>`
enforce the implementation of a :meth:`_numpycode` method in case the class does not
correctly :func:`~sympy.utilities.lambdify.lambdify` to NumPy code. For more info on
SymPy printers, see :doc:`sympy:modules/printing`.
This interface is for classes that derive from `sympy.Expr <sympy.core.expr.Expr>`
and that require a :meth:`_numpycode` method in case the class does not correctly
:func:`~sympy.utilities.lambdify.lambdify` to NumPy code. For more info on SymPy
printers, see :doc:`sympy:modules/printing`.

Several computational frameworks try to converge their interface to that of NumPy.
See for instance `TensorFlow's NumPy API
Expand All @@ -179,9 +65,9 @@
:func:`~sympy.utilities.lambdify.lambdify` SymPy expressions to these different
backends with the same lambdification code.

.. note:: This interface differs from `UnevaluatedExpression` in that it **should
not** implement an :meth:`.evaluate` (and therefore a
:meth:`~sympy.core.basic.Basic.doit`) method.
.. warning:: If you decorate this class with :func:`unevaluated`, you usually want
to do so with :code:`implement_doit=False`, because you do not want the class
to be 'unfolded' with :meth:`~sympy.core.basic.Basic.doit` before lambdification.


.. warning:: The implemented :meth:`_numpycode` method should countain as little
Expand All @@ -201,117 +87,6 @@
"""Lambdify this `NumPyPrintable` class to NumPy code."""


DecoratedClass = TypeVar("DecoratedClass", bound=UnevaluatedExpression)
"""`~typing.TypeVar` for decorators like :func:`implement_doit_method`."""


def implement_expr(
n_args: int,
) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]:
"""Decorator for classes that derive from `UnevaluatedExpression`.

Implement a :meth:`~object.__new__` and :meth:`~sympy.core.basic.Basic.doit` method
for a class that derives from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`).
"""

def decorator(
decorated_class: type[DecoratedClass],
) -> type[DecoratedClass]:
decorated_class = implement_new_method(n_args)(decorated_class)
return implement_doit_method(decorated_class)

return decorator


def implement_new_method(
n_args: int,
) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]:
"""Implement :meth:`UnevaluatedExpression.__new__` on a derived class.

Implement a :meth:`~object.__new__` method for a class that derives from
`~sympy.core.expr.Expr` (via `UnevaluatedExpression`).
"""

def decorator(
decorated_class: type[DecoratedClass],
) -> type[DecoratedClass]:
def new_method(
cls: type[DecoratedClass],
*args: sp.Symbol,
evaluate: bool = False,
**hints,
) -> DecoratedClass:
if len(args) != n_args:
msg = f"{n_args} parameters expected, got {len(args)}"
raise ValueError(msg)
args = sp.sympify(args)
expr = UnevaluatedExpression.__new__(cls, *args)
if evaluate:
return expr.evaluate() # type: ignore[return-value]
return expr

decorated_class.__new__ = new_method # type: ignore[assignment]
return decorated_class

return decorator


def implement_doit_method(
decorated_class: type[DecoratedClass],
) -> type[DecoratedClass]:
"""Implement ``doit()`` method for an `UnevaluatedExpression` class.

Implement a :meth:`~sympy.core.basic.Basic.doit` method for a class that derives
from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). A
:meth:`~sympy.core.basic.Basic.doit` method is an extension of an
:meth:`~.UnevaluatedExpression.evaluate` method in the sense that it can work
recursively on deeper expression trees.
"""

@functools.wraps(decorated_class.doit) # type: ignore[attr-defined]
def doit_method(self: UnevaluatedExpression, deep: bool = True) -> sp.Expr:
expr = self.evaluate()
if deep:
return expr.doit()
return expr

decorated_class.doit = doit_method # type: ignore[assignment]
return decorated_class


DecoratedExpr = TypeVar("DecoratedExpr", bound=sp.Expr)
"""`~typing.TypeVar` for decorators like :func:`make_commutative`."""


def make_commutative(
decorated_class: type[DecoratedExpr],
) -> type[DecoratedExpr]:
"""Set commutative and 'extended real' assumptions on expression class.

.. seealso:: :doc:`sympy:guides/assumptions`
"""
decorated_class.is_commutative = True # type: ignore[attr-defined]
decorated_class.is_extended_real = True # type: ignore[attr-defined]
return decorated_class


def create_expression(
cls: type[DecoratedExpr],
*args,
evaluate: bool = False,
name: str | None = None,
**kwargs,
) -> DecoratedExpr:
"""Helper function for implementing `UnevaluatedExpression.__new__`."""
args = sp.sympify(args)
if issubclass(cls, UnevaluatedExpression):
expr = UnevaluatedExpression.__new__(cls, *args, name=name, **kwargs)
if evaluate:
return expr.evaluate() # type: ignore[return-value]
return expr # type: ignore[return-value]
return sp.Expr.__new__(cls, *args, **kwargs) # type: ignore[return-value]


def create_symbol_matrix(name: str, m: int, n: int) -> sp.MutableDenseMatrix:
"""Create a `~sympy.matrices.dense.Matrix` with symbols as elements.

Expand All @@ -332,8 +107,7 @@
return sp.Matrix([[symbol[i, j] for j in range(n)] for i in range(m)])


@implement_doit_method
class PoolSum(UnevaluatedExpression):
class PoolSum(sp.Expr):
r"""Sum over indices where the values are taken from a domain set.

>>> i, j, m, n = sp.symbols("i j m n")
Expand All @@ -352,6 +126,7 @@
cls,
expression,
*indices: tuple[sp.Symbol, Iterable[sp.Basic]],
evaluate: bool = False,
**hints,
) -> PoolSum:
converted_indices = []
Expand All @@ -361,7 +136,11 @@
msg = f"No values provided for index {idx_symbol}"
raise ValueError(msg)
converted_indices.append((idx_symbol, values))
return create_expression(cls, expression, *converted_indices, **hints)
args = sp.sympify((expression, *converted_indices))
expr: PoolSum = sp.Expr.__new__(cls, *args, **hints)
if evaluate:
return expr.evaluate() # type: ignore[return-value]

Check warning on line 142 in src/ampform/sympy/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/__init__.py#L142

Added line #L142 was not covered by tests
return expr

@property
def expression(self) -> sp.Expr:
Expand All @@ -375,6 +154,12 @@
def free_symbols(self) -> set[sp.Basic]:
return super().free_symbols - {s for s, _ in self.indices}

def doit(self, deep: bool = True) -> sp.Expr: # type: ignore[override]
expr = self.evaluate()
if deep:
return expr.doit()
return expr

Check warning on line 161 in src/ampform/sympy/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/__init__.py#L161

Added line #L161 was not covered by tests

def evaluate(self) -> sp.Expr:
indices = {symbol: tuple(values) for symbol, values in self.indices}
return sp.Add(*[
Expand Down
Loading
Loading