Skip to content

Commit

Permalink
MAINT: move expression classes to ampform.sympy.deprecated
Browse files Browse the repository at this point in the history
FIX: add newline before doctest
  • Loading branch information
redeboer committed Dec 22, 2023
1 parent db9cf20 commit 843f3cd
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 239 deletions.
12 changes: 6 additions & 6 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,30 +49,28 @@
add_module_names = False
api_github_repo = f"{ORGANIZATION}/{REPO_NAME}"
api_target_substitutions: dict[str, str | tuple[str, str]] = {
"T": "TypeVar",
"BuilderReturnType": ("obj", "ampform.dynamics.builder.BuilderReturnType"),
"DecoratedClass": ("obj", "ampform.sympy.DecoratedClass"),
"DecoratedExpr": ("obj", "ampform.sympy.DecoratedExpr"),
"ExprClass": "ampform.sympy.ExprClass",
"DecoratedClass": ("obj", "ampform.sympy.deprecated.DecoratedClass"),
"DecoratedExpr": ("obj", "ampform.sympy.deprecated.DecoratedExpr"),
"FourMomenta": ("obj", "ampform.kinematics.lorentz.FourMomenta"),
"FourMomentumSymbol": ("obj", "ampform.kinematics.lorentz.FourMomentumSymbol"),
"InteractionProperties": "qrules.quantum_numbers.InteractionProperties",
"LatexPrinter": "sympy.printing.printer.Printer",
"Literal[(-1, 1)]": "typing.Literal",
"Literal[-1, 1]": "typing.Literal",
"NumPyPrintable": ("class", "ampform.sympy.NumPyPrintable"),
"NumPyPrinter": "sympy.printing.printer.Printer",
"ParameterValue": ("obj", "ampform.helicity.ParameterValue"),
"Particle": "qrules.particle.Particle",
"ReactionInfo": "qrules.transition.ReactionInfo",
"Slider": ("obj", "symplot.Slider"),
"State": "qrules.transition.State",
"StateTransition": "qrules.transition.StateTransition",
"T": "TypeVar",
"Topology": "qrules.topology.Topology",
"WignerD": "sympy.physics.quantum.spin.WignerD",
"ampform.helicity._T": "typing.TypeVar",
"ampform.sympy._decorator.ExprClass": ("obj", "ampform.sympy.ExprClass"),
"ampform.sympy._decorator.SymPyAssumptions": "ampform.sympy.SymPyAssumptions",
"an object providing a view on D's values": "typing.ValuesView",
"sp.Basic": "sympy.core.basic.Basic",
"sp.Expr": "sympy.core.expr.Expr",
"sp.Float": "sympy.core.numbers.Float",
Expand Down Expand Up @@ -289,7 +287,9 @@
nb_output_stderr = "remove"
nitpick_ignore = [
("py:class", "ArraySum"),
("py:class", "ExprClass"),
("py:class", "MatrixMultiplication"),
("py:class", "ampform.sympy._decorator.SymPyAssumptions"),
]
nitpicky = True
primary_domain = "py"
Expand Down
241 changes: 8 additions & 233 deletions src/ampform/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"""

# cspell:ignore mhash
from __future__ import annotations

import functools
Expand All @@ -23,7 +22,7 @@
from abc import abstractmethod
from os.path import abspath, dirname, expanduser
from textwrap import dedent
from typing import TYPE_CHECKING, Callable, Iterable, Sequence, SupportsFloat, TypeVar
from typing import TYPE_CHECKING, Iterable, Sequence, SupportsFloat

import sympy as sp
from sympy.printing.conventions import split_super_sub
Expand All @@ -35,6 +34,13 @@
argument, # noqa: F401 # pyright: ignore[reportUnusedImport]
unevaluated, # noqa: F401 # pyright: ignore[reportUnusedImport]
)
from .deprecated import (
UnevaluatedExpression, # noqa: F401 # pyright: ignore[reportUnusedImport]
create_expression, # noqa: F401 # pyright: ignore[reportUnusedImport]
implement_doit_method, # noqa: F401 # pyright: ignore[reportUnusedImport]
implement_expr, # pyright: ignore[reportUnusedImport] # noqa: F401
make_commutative, # pyright: ignore[reportUnusedImport] # noqa: F401
)

if TYPE_CHECKING:
from sympy.printing.latex import LatexPrinter
Expand All @@ -43,126 +49,6 @@
_LOGGER = logging.getLogger(__name__)


class UnevaluatedExpression(sp.Expr):
"""Base class for expression classes with an :meth:`evaluate` method.
Deriving from `~sympy.core.expr.Expr` allows us to keep expression trees condense
before unfolding them with their `~sympy.core.basic.Basic.doit` method. This allows
us to:
1. condense the LaTeX representation of an expression tree by providing a custom
:meth:`_latex` method.
2. overwrite its printer methods (see `NumPyPrintable` and e.g.
:doc:`compwa-org:report/001`).
The `UnevaluatedExpression` base class makes implementations of its derived classes
more secure by enforcing the developer to provide implementations for these methods,
so that SymPy mechanisms work correctly. Decorators like :func:`implement_expr` and
:func:`implement_doit_method` provide convenient means to implement the missing
methods.
.. autolink-preface::
import sympy as sp
from ampform.sympy import UnevaluatedExpression, create_expression
.. automethod:: __new__
.. automethod:: evaluate
.. automethod:: _latex
"""

# https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L74-L77
__slots__: tuple[str] = ("_name",)
_name: str | None
"""Optional instance attribute that can be used in LaTeX representations."""

def __new__(
cls: type[DecoratedClass],
*args,
name: str | None = None,
**hints,
) -> DecoratedClass:
"""Constructor for a class derived from `UnevaluatedExpression`.
This :meth:`~object.__new__` method correctly sets the
`~sympy.core.basic.Basic.args`, assumptions etc. Overwrite it in order to
further specify its signature. The function :func:`create_expression` can be
used in its implementation, like so:
>>> class MyExpression(UnevaluatedExpression):
... def __new__(
... cls, x: sp.Symbol, y: sp.Symbol, n: int, **hints
... ) -> "MyExpression":
... return create_expression(cls, x, y, n, **hints)
...
... def evaluate(self) -> sp.Expr:
... x, y, n = self.args
... return (x + y)**n
...
>>> x, y = sp.symbols("x y")
>>> expr = MyExpression(x, y, n=3)
>>> expr
MyExpression(x, y, 3)
>>> expr.evaluate()
(x + y)**3
"""
# https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L113-L119
obj = object.__new__(cls)
obj._args = args
obj._assumptions = cls.default_assumptions # type: ignore[attr-defined]
obj._mhash = None
obj._name = name
return obj

def __getnewargs_ex__(self) -> tuple[tuple, dict]:
# Pickling support, see
# https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L124-L126
args = tuple(self.args)
kwargs = {"name": self._name}
return args, kwargs

def _hashable_content(self) -> tuple:
# https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165
# name is converted to string because unstable hash for None
return (*super()._hashable_content(), str(self._name))

@abstractmethod
def evaluate(self) -> sp.Expr:
"""Evaluate and 'unfold' this `UnevaluatedExpression` by one level.
>>> from ampform.dynamics import BreakupMomentumSquared
>>> s, m1, m2 = sp.symbols("s m1 m2")
>>> expr = BreakupMomentumSquared(s, m1, m2)
>>> expr
BreakupMomentumSquared(s, m1, m2)
>>> expr.evaluate()
(s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s)
>>> expr.doit(deep=False)
(s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s)
.. note:: When decorating this class with :func:`implement_doit_method`,
its :meth:`evaluate` method is equivalent to
:meth:`~sympy.core.basic.Basic.doit` with :code:`deep=False`.
"""

def _latex(self, printer: LatexPrinter, *args) -> str:
r"""Provide a mathematical Latex representation for pretty printing.
>>> from ampform.dynamics import BreakupMomentumSquared
>>> s, m1 = sp.symbols("s m1")
>>> expr = BreakupMomentumSquared(s, m1, m1)
>>> print(sp.latex(expr))
q^2\left(s\right)
>>> print(sp.latex(expr.doit()))
- m_{1}^{2} + \frac{s}{4}
"""
args = tuple(map(printer._print, self.args))
name = type(self).__name__
if self._name is not None:
name = self._name
return f"{name}{args}"


class NumPyPrintable(sp.Expr):
r"""`~sympy.core.expr.Expr` class that can lambdify to NumPy code.
Expand Down Expand Up @@ -201,117 +87,6 @@ def _numpycode(self, printer: NumPyPrinter, *args) -> str:
"""Lambdify this `NumPyPrintable` class to NumPy code."""


DecoratedClass = TypeVar("DecoratedClass", bound=UnevaluatedExpression)
"""`~typing.TypeVar` for decorators like :func:`implement_doit_method`."""


def implement_expr(
n_args: int,
) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]:
"""Decorator for classes that derive from `UnevaluatedExpression`.
Implement a :meth:`~object.__new__` and :meth:`~sympy.core.basic.Basic.doit` method
for a class that derives from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`).
"""

def decorator(
decorated_class: type[DecoratedClass],
) -> type[DecoratedClass]:
decorated_class = implement_new_method(n_args)(decorated_class)
return implement_doit_method(decorated_class)

return decorator


def implement_new_method(
n_args: int,
) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]:
"""Implement :meth:`UnevaluatedExpression.__new__` on a derived class.
Implement a :meth:`~object.__new__` method for a class that derives from
`~sympy.core.expr.Expr` (via `UnevaluatedExpression`).
"""

def decorator(
decorated_class: type[DecoratedClass],
) -> type[DecoratedClass]:
def new_method(
cls: type[DecoratedClass],
*args: sp.Symbol,
evaluate: bool = False,
**hints,
) -> DecoratedClass:
if len(args) != n_args:
msg = f"{n_args} parameters expected, got {len(args)}"
raise ValueError(msg)
args = sp.sympify(args)
expr = UnevaluatedExpression.__new__(cls, *args)
if evaluate:
return expr.evaluate() # type: ignore[return-value]
return expr

decorated_class.__new__ = new_method # type: ignore[assignment]
return decorated_class

return decorator


def implement_doit_method(
decorated_class: type[DecoratedClass],
) -> type[DecoratedClass]:
"""Implement ``doit()`` method for an `UnevaluatedExpression` class.
Implement a :meth:`~sympy.core.basic.Basic.doit` method for a class that derives
from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). A
:meth:`~sympy.core.basic.Basic.doit` method is an extension of an
:meth:`~.UnevaluatedExpression.evaluate` method in the sense that it can work
recursively on deeper expression trees.
"""

@functools.wraps(decorated_class.doit) # type: ignore[attr-defined]
def doit_method(self: UnevaluatedExpression, deep: bool = True) -> sp.Expr:
expr = self.evaluate()
if deep:
return expr.doit()
return expr

decorated_class.doit = doit_method # type: ignore[assignment]
return decorated_class


DecoratedExpr = TypeVar("DecoratedExpr", bound=sp.Expr)
"""`~typing.TypeVar` for decorators like :func:`make_commutative`."""


def make_commutative(
decorated_class: type[DecoratedExpr],
) -> type[DecoratedExpr]:
"""Set commutative and 'extended real' assumptions on expression class.
.. seealso:: :doc:`sympy:guides/assumptions`
"""
decorated_class.is_commutative = True # type: ignore[attr-defined]
decorated_class.is_extended_real = True # type: ignore[attr-defined]
return decorated_class


def create_expression(
cls: type[DecoratedExpr],
*args,
evaluate: bool = False,
name: str | None = None,
**kwargs,
) -> DecoratedExpr:
"""Helper function for implementing `UnevaluatedExpression.__new__`."""
args = sp.sympify(args)
if issubclass(cls, UnevaluatedExpression):
expr = UnevaluatedExpression.__new__(cls, *args, name=name, **kwargs)
if evaluate:
return expr.evaluate() # type: ignore[return-value]
return expr # type: ignore[return-value]
return sp.Expr.__new__(cls, *args, **kwargs) # type: ignore[return-value]


def create_symbol_matrix(name: str, m: int, n: int) -> sp.MutableDenseMatrix:
"""Create a `~sympy.matrices.dense.Matrix` with symbols as elements.
Expand Down
Loading

0 comments on commit 843f3cd

Please sign in to comment.