Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BREAK: rename unevaluated_expression() to unevaluated() #379

Merged
merged 13 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@
"sharey",
"startswith",
"suptitle",
"sympifiable",
"sympified",
"sympify",
"symplot",
Expand Down
13 changes: 7 additions & 6 deletions docs/usage/sympy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The {func}`.unevaluated_expression` decorator makes it easier to write classes that represent a mathematical function definition. It makes a class that derives from {class}`sp.Expr <sympy.core.expr.Expr>` behave more like a {func}`~.dataclasses.dataclass` (see [PEP&nbsp;861](https://peps.python.org/pep-0681)). All you have to do is:\n",
"The {func}`.unevaluated` decorator makes it easier to write classes that represent a mathematical function definition. It makes a class that derives from {class}`sp.Expr <sympy.core.expr.Expr>` behave more like a {func}`~.dataclasses.dataclass` (see [PEP&nbsp;861](https://peps.python.org/pep-0681)). All you have to do is:\n",
"\n",
"1. Specify the arguments the function requires.\n",
"2. Specify how to render the 'unevaluated' or 'folded' form of the expression with a `_latex_repr_` string or method.\n",
Expand All @@ -98,10 +98,10 @@
"source": [
"import sympy as sp\n",
"\n",
"from ampform.sympy import unevaluated_expression\n",
"from ampform.sympy import unevaluated\n",
"\n",
"\n",
"@unevaluated_expression(real=False)\n",
"@unevaluated(real=False)\n",
"class PhspFactorSWave(sp.Expr):\n",
" s: sp.Symbol\n",
" m1: sp.Symbol\n",
Expand All @@ -119,7 +119,7 @@
" return 16 * sp.pi * sp.I * cm\n",
"\n",
"\n",
"@unevaluated_expression(real=False)\n",
"@unevaluated(real=False)\n",
"class BreakupMomentum(sp.Expr):\n",
" s: sp.Symbol\n",
" m1: sp.Symbol\n",
Expand Down Expand Up @@ -166,7 +166,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Class variables and default arguments to instance arguments are also supported:"
"Class variables and default arguments to instance arguments are also supported. They can either be indicated with {class}`typing.ClassVar` or by not providing a type hint:"
]
},
{
Expand All @@ -180,11 +180,12 @@
"from typing import Any, ClassVar\n",
"\n",
"\n",
"@unevaluated_expression\n",
"@unevaluated\n",
"class FunkyPower(sp.Expr):\n",
" x: Any\n",
" m: int = 1\n",
" default_return: ClassVar[sp.Expr | None] = None\n",
" class_name = \"my name\"\n",
" _latex_repr_ = R\"f_{{{m}}}\\left({x}\\right)\"\n",
"\n",
" def evaluate(self) -> sp.Expr | None:\n",
Expand Down
4 changes: 2 additions & 2 deletions src/ampform/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
UnevaluatedExpression,
determine_indices,
implement_doit_method,
unevaluated_expression,
unevaluated,
)

if TYPE_CHECKING:
from sympy.printing.latex import LatexPrinter


@unevaluated_expression
@unevaluated
class BlattWeisskopfSquared(sp.Expr):
# cspell:ignore pychyGekoppeltePartialwellenanalyseAnnihilationen
r"""Blatt-Weisskopf function :math:`B_L^2(z)`, up to :math:`L \leq 8`.
Expand Down
6 changes: 3 additions & 3 deletions src/ampform/kinematics/phasespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

import sympy as sp

from ampform.sympy import unevaluated_expression
from ampform.sympy import unevaluated


@unevaluated_expression
@unevaluated
class Kibble(sp.Expr):
"""Kibble function for determining the phase space region."""

Expand All @@ -34,7 +34,7 @@ def evaluate(self) -> Kallen:
)


@unevaluated_expression
@unevaluated
class Kallen(sp.Expr):
"""Källén function, used for computing break-up momenta."""

Expand Down
4 changes: 2 additions & 2 deletions src/ampform/sympy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tools that facilitate in building :mod:`sympy` expressions.

.. autodecorator:: unevaluated_expression
.. autodecorator:: unevaluated
.. dropdown:: SymPy assumptions

.. autodata:: ExprClass
Expand Down Expand Up @@ -30,7 +30,7 @@
from ._decorator import (
ExprClass, # noqa: F401 # pyright: ignore[reportUnusedImport]
SymPyAssumptions, # noqa: F401 # pyright: ignore[reportUnusedImport]
unevaluated_expression, # noqa: F401 # pyright: ignore[reportUnusedImport]
unevaluated, # noqa: F401 # pyright: ignore[reportUnusedImport]
)

if TYPE_CHECKING:
Expand Down
179 changes: 159 additions & 20 deletions src/ampform/sympy/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,38 @@
import functools
import inspect
import sys
from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, overload
from collections import abc
from inspect import isclass
from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterable, TypeVar, overload

import sympy as sp
from attrs import frozen
from sympy.core.basic import _aresame
from sympy.utilities.exceptions import SymPyDeprecationWarning

if sys.version_info < (3, 8):
from typing_extensions import Protocol, TypedDict
else:
from typing import Protocol, TypedDict

if sys.version_info < (3, 11):
from typing_extensions import ParamSpec, Unpack, dataclass_transform
from typing_extensions import dataclass_transform
else:
from typing import ParamSpec, Unpack, dataclass_transform
from typing import dataclass_transform

Check warning on line 23 in src/ampform/sympy/_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/_decorator.py#L23

Added line #L23 was not covered by tests

if TYPE_CHECKING:
from sympy.printing.latex import LatexPrinter

if sys.version_info < (3, 11):
from typing_extensions import ParamSpec, Unpack

Check warning on line 29 in src/ampform/sympy/_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/_decorator.py#L29

Added line #L29 was not covered by tests
else:
from typing import ParamSpec, Unpack

Check warning on line 31 in src/ampform/sympy/_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/_decorator.py#L31

Added line #L31 was not covered by tests

H = TypeVar("H", bound=Hashable)
P = ParamSpec("P")
T = TypeVar("T")

Check warning on line 35 in src/ampform/sympy/_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/_decorator.py#L33-L35

Added lines #L33 - L35 were not covered by tests

ExprClass = TypeVar("ExprClass", bound=sp.Expr)
_P = ParamSpec("_P")
_T = TypeVar("_T")


class SymPyAssumptions(TypedDict, total=False):
Expand Down Expand Up @@ -56,25 +68,23 @@


@overload
def unevaluated_expression(cls: type[ExprClass]) -> type[ExprClass]: ...
def unevaluated(cls: type[ExprClass]) -> type[ExprClass]: ...
@overload
def unevaluated_expression(
def unevaluated(
*,
implement_doit: bool = True,
**assumptions: Unpack[SymPyAssumptions],
) -> Callable[[type[ExprClass]], type[ExprClass]]: ...


@dataclass_transform() # type: ignore[misc]
def unevaluated_expression( # type: ignore[misc]
@dataclass_transform()
def unevaluated(
cls: type[ExprClass] | None = None, *, implement_doit=True, **assumptions
):
r"""Decorator for defining 'unevaluated' SymPy expressions.

Unevaluated expressions are handy for defining large expressions that consist of
several sub-definitions.

>>> @unevaluated_expression
>>> @unevaluated
... class MyExpr(sp.Expr):
... x: sp.Symbol
... y: sp.Symbol
Expand Down Expand Up @@ -133,22 +143,54 @@
@functools.wraps(cls.__new__)
@_insert_args_in_signature(attr_names, idx=1)
def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]:
positional_args, hints = _get_attribute_values(cls, attr_names, *args, **kwargs)
sympified_args = sp.sympify(positional_args)
expr = sp.Expr.__new__(cls, *sympified_args, **hints)
for name, value in zip(attr_names, sympified_args):
attr_values, hints = _get_attribute_values(cls, attr_names, *args, **kwargs)
converted_attr_values = _safe_sympify(*attr_values)
expr = sp.Expr.__new__(cls, *converted_attr_values.sympy, **hints)
for name, value in zip(attr_names, converted_attr_values.all_args):
setattr(expr, name, value)
expr._all_args = converted_attr_values.all_args
expr._non_sympy_args = converted_attr_values.non_sympy
if evaluate:
return expr.evaluate()
return expr

cls.__new__ = new_method # type: ignore[method-assign]
cls._eval_subs = _eval_subs_method # type: ignore[method-assign]
cls._hashable_content = _hashable_content_method # type: ignore[method-assign]
cls._xreplace = _xreplace_method # type: ignore[method-assign]
return cls


@overload
def _get_hashable_object(obj: type) -> str: ... # type: ignore[overload-overlap]
@overload
def _get_hashable_object(obj: H) -> H: ...
@overload
def _get_hashable_object(obj: Any) -> str: ...
def _get_hashable_object(obj):
if isclass(obj):
return str(obj)
try:
hash(obj)
except TypeError:
return str(obj)
return obj

Check warning on line 177 in src/ampform/sympy/_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/_decorator.py#L173-L177

Added lines #L173 - L177 were not covered by tests


def _get_attribute_values(
cls: type[ExprClass], attr_names: tuple[str, ...], *args, **kwargs
) -> tuple[tuple, dict[str, Any]]:
"""Extract the attribute values from the constructor arguments.

Returns a `tuple` of:

1. the extracted, ordered attributes as requested by :code:`attr_names`,
2. a `dict` of remaining keyword arguments that can be used hints for the
constructed :class:`sp.Expr<sympy.core.expr.Expr>` instance.

An attempt is made to get any missing attributes from the type hints in the class
definition.
"""
if len(args) == len(attr_names):
return args, kwargs
if len(args) > len(attr_names):
Expand All @@ -173,12 +215,46 @@
return tuple(attr_values), kwargs


def _safe_sympify(*args: Any) -> _ExprNewArumgents:
all_args = []
sympy_args = []
non_sympy_args = []
for arg in args:
converted_arg, is_sympy = _try_sympify(arg)
if is_sympy:
sympy_args.append(converted_arg)
else:
non_sympy_args.append(converted_arg)
all_args.append(converted_arg)
return _ExprNewArumgents(
all_args=tuple(all_args),
sympy=tuple(sympy_args),
non_sympy=tuple(non_sympy_args),
)


def _try_sympify(obj) -> tuple[Any, bool]:
if isinstance(obj, str):
return obj, False

Check warning on line 238 in src/ampform/sympy/_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/_decorator.py#L238

Added line #L238 was not covered by tests
try:
return sp.sympify(obj), True
except (TypeError, SymPyDeprecationWarning, sp.SympifyError):
return obj, False


@frozen
class _ExprNewArumgents:
all_args: tuple[Any, ...]
sympy: tuple[sp.Basic, ...]
non_sympy: tuple[Any, ...]


class LatexMethod(Protocol):
def __call__(self, printer: LatexPrinter, *args) -> str: ...


@dataclass_transform()
def _implement_latex_repr(cls: type[_T]) -> type[_T]:
def _implement_latex_repr(cls: type[T]) -> type[T]:
_latex_repr_: LatexMethod | str | None = getattr(cls, "_latex_repr_", None)
if _latex_repr_ is None:
msg = (
Expand Down Expand Up @@ -228,7 +304,7 @@

def _insert_args_in_signature(
new_params: Iterable[str] | None = None, idx: int = 0
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
if new_params is None:
new_params = []

Expand Down Expand Up @@ -279,10 +355,73 @@
@dataclass_transform()
def _set_assumptions(
**assumptions: Unpack[SymPyAssumptions],
) -> Callable[[type[_T]], type[_T]]:
def class_wrapper(cls: _T) -> _T:
) -> Callable[[type[T]], type[T]]:
def class_wrapper(cls: T) -> T:
for assumption, value in assumptions.items():
setattr(cls, f"is_{assumption}", value)
return cls

return class_wrapper


def _eval_subs_method(self, old, new, **hints):
# https://github.com/sympy/sympy/blob/1.12/sympy/core/basic.py#L1117-L1147
hit = False
substituted_attrs = list(self._all_args)
for i, old_attr in enumerate(substituted_attrs):
if not hasattr(old_attr, "_eval_subs"):
continue
if isclass(old_attr):
continue

Check warning on line 375 in src/ampform/sympy/_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/_decorator.py#L375

Added line #L375 was not covered by tests
new_attr = old_attr._subs(old, new, **hints)
if not _aresame(new_attr, old_attr):
hit = True
substituted_attrs[i] = new_attr
if hit:
rv = self.func(*substituted_attrs)
hack2 = hints.get("hack2", False)
if hack2 and self.is_Mul and not rv.is_Mul: # 2-arg hack
coefficient = sp.S.One
nonnumber = []

Check warning on line 385 in src/ampform/sympy/_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/_decorator.py#L384-L385

Added lines #L384 - L385 were not covered by tests
for i in substituted_attrs:
if i.is_Number:
coefficient *= i

Check warning on line 388 in src/ampform/sympy/_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/_decorator.py#L388

Added line #L388 was not covered by tests
else:
nonnumber.append(i)
nonnumber = self.func(*nonnumber)

Check warning on line 391 in src/ampform/sympy/_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/_decorator.py#L390-L391

Added lines #L390 - L391 were not covered by tests
if coefficient is sp.S.One:
return nonnumber
return self.func(coefficient, nonnumber, evaluate=False)

Check warning on line 394 in src/ampform/sympy/_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/_decorator.py#L393-L394

Added lines #L393 - L394 were not covered by tests
return rv
return self

Check warning on line 396 in src/ampform/sympy/_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/_decorator.py#L396

Added line #L396 was not covered by tests


def _hashable_content_method(self) -> tuple:
hashable_content = super(sp.Expr, self)._hashable_content()
if not self._non_sympy_args:
return hashable_content
remaining_content = (_get_hashable_object(arg) for arg in self._non_sympy_args)
return (*hashable_content, *remaining_content)


def _xreplace_method(self, rule) -> tuple[sp.Expr, bool]:
# https://github.com/sympy/sympy/blob/1.12/sympy/core/basic.py#L1233-L1253
if self in rule:
return rule[self], True

Check warning on line 410 in src/ampform/sympy/_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/_decorator.py#L410

Added line #L410 was not covered by tests
if rule:
new_args = []
hit = False
for arg in self._all_args:
if hasattr(arg, "_xreplace") and not isclass(arg):
replace_result, is_replaced = arg._xreplace(rule)
elif isinstance(rule, abc.Mapping):
is_replaced = bool(arg in rule)
replace_result = rule.get(arg, arg)
else:
replace_result = arg
is_replaced = False

Check warning on line 422 in src/ampform/sympy/_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/ampform/sympy/_decorator.py#L421-L422

Added lines #L421 - L422 were not covered by tests
new_args.append(replace_result)
hit |= is_replaced
if hit:
return self.func(*new_args), True
return self, False
Loading
Loading