From ce5d6aed1775cd3d41ac1c393d69d04786f110a0 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 15 Dec 2023 12:40:11 +0100 Subject: [PATCH] FEAT: implement `unevaluated_expression` decorator (#365) * DOC: add notebook for SymPy helper functions * DOC: illustrate usage of `PoolSum` * DX: skip ipywidgets tests that copy widgets * DX: upgrade developer environment to Python 3.11 * ENH: make implementation method public as `evaluate()` * FEAT: implement `unevaluated_expression` decorator * MAINT: update Codecov config style --- .cspell.json | 1 + .github/workflows/ci.yml | 3 + .gitpod.yml | 4 +- .pre-commit-config.yaml | 1 + .readthedocs.yml | 4 +- .vscode/settings.json | 1 + codecov.yml | 6 +- docs/conf.py | 5 +- docs/usage.ipynb | 1 + docs/usage/sympy.ipynb | 232 ++++++++++++++++++++++ environment.yml | 4 +- pyproject.toml | 4 + src/ampform/kinematics/phasespace.py | 43 ++--- src/ampform/sympy/__init__.py | 16 +- src/ampform/sympy/_decorator.py | 275 +++++++++++++++++++++++++++ tests/symplot/test_symplot.py | 6 +- tests/sympy/test_decorator.py | 127 +++++++++++++ 17 files changed, 694 insertions(+), 39 deletions(-) create mode 100644 docs/usage/sympy.ipynb create mode 100644 src/ampform/sympy/_decorator.py create mode 100644 tests/sympy/test_decorator.py diff --git a/.cspell.json b/.cspell.json index eb60f233f..3171bd53b 100644 --- a/.cspell.json +++ b/.cspell.json @@ -168,6 +168,7 @@ "pyright", "pytestconfig", "rankdir", + "repr", "richman", "rightarrow", "risch", diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 95ed28ea7..109d8fcfe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,6 +33,7 @@ jobs: id-token: write with: apt-packages: graphviz + python-version: "3.11" specific-pip-packages: ${{ inputs.specific-pip-packages }} pytest: uses: ComPWA/actions/.github/workflows/pytest.yml@v1 @@ -45,3 +46,5 @@ jobs: secrets: token: ${{ secrets.PAT }} uses: ComPWA/actions/.github/workflows/pre-commit.yml@v1 + with: + python-version: "3.11" diff --git a/.gitpod.yml b/.gitpod.yml index fb4acd8d6..e8046e720 100644 --- a/.gitpod.yml +++ b/.gitpod.yml @@ -1,6 +1,6 @@ tasks: - - init: pyenv local 3.8 - - init: pip install -c .constraints/py3.8.txt -e .[dev] + - init: pyenv local 3.11 + - init: pip install -c .constraints/py3.11.txt -e .[dev] github: prebuilds: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c0fb6ecab..326af27d4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,6 +20,7 @@ repos: - id: check-dev-files args: - --doc-apt-packages=graphviz + - --dev-python-version=3.11 - --no-prettierrc - --pin-requirements=monthly - --repo-name=ampform diff --git a/.readthedocs.yml b/.readthedocs.yml index f34ae963f..1362c4ee2 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -11,9 +11,9 @@ formats: build: os: ubuntu-22.04 tools: - python: "3.8" + python: "3.11" apt_packages: - graphviz jobs: post_install: - - pip install -c .constraints/py3.8.txt -e .[doc] + - pip install -c .constraints/py3.11.txt -e .[doc] diff --git a/.vscode/settings.json b/.vscode/settings.json index f46e33649..e74f30354 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -63,6 +63,7 @@ "ruff.enable": true, "ruff.organizeImports": true, "search.exclude": { + "typings/**": true, "**/tests/**/__init__.py": true, ".constraints/*.txt": true }, diff --git a/codecov.yml b/codecov.yml index 37356a53c..64157f16e 100644 --- a/codecov.yml +++ b/codecov.yml @@ -14,13 +14,13 @@ coverage: threshold: 1% # allow drops by this percentage base: auto # advanced - branches: null + branches: [] if_no_uploads: error if_not_found: success if_ci_failed: error only_pulls: false - flags: null - paths: null + flags: [] + paths: [] patch: default: # basic diff --git a/docs/conf.py b/docs/conf.py index a0dd94647..4830ebd13 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -52,11 +52,13 @@ "BuilderReturnType": ("obj", "ampform.dynamics.builder.BuilderReturnType"), "DecoratedClass": ("obj", "ampform.sympy.DecoratedClass"), "DecoratedExpr": ("obj", "ampform.sympy.DecoratedExpr"), + "ExprClass": "ampform.sympy.ExprClass", "FourMomenta": ("obj", "ampform.kinematics.FourMomenta"), "FourMomentumSymbol": ("obj", "ampform.kinematics.FourMomentumSymbol"), "InteractionProperties": "qrules.quantum_numbers.InteractionProperties", "LatexPrinter": "sympy.printing.printer.Printer", "Literal[(-1, 1)]": "typing.Literal", + "Literal[-1, 1]": "typing.Literal", "NumPyPrinter": "sympy.printing.printer.Printer", "ParameterValue": ("obj", "ampform.helicity.ParameterValue"), "Particle": "qrules.particle.Particle", @@ -68,6 +70,7 @@ "WignerD": "sympy.physics.quantum.spin.WignerD", "ampform.helicity._T": "typing.TypeVar", "ampform.sympy._decorator.ExprClass": ("obj", "ampform.sympy.ExprClass"), + "ampform.sympy._decorator.SymPyAssumptions": "ampform.sympy.SymPyAssumptions", "an object providing a view on D's values": "typing.ValuesView", "sp.Basic": "sympy.core.basic.Basic", "sp.Expr": "sympy.core.expr.Expr", @@ -283,7 +286,7 @@ nb_execution_timeout = -1 nb_output_stderr = "remove" nitpick_ignore = [ - ("py:class", "ArraySum"), + ("py:class", "ampform.sympy._array_expressions.ArraySum"), ("py:class", "ampform.sympy._array_expressions.MatrixMultiplication"), ] nitpicky = True diff --git a/docs/usage.ipynb b/docs/usage.ipynb index 471bcc5ad..626f0dedc 100644 --- a/docs/usage.ipynb +++ b/docs/usage.ipynb @@ -336,6 +336,7 @@ "usage/helicity/formalism\n", "usage/helicity/spin-alignment\n", "usage/kinematics\n", + "usage/sympy\n", "```" ] } diff --git a/docs/usage/sympy.ipynb b/docs/usage/sympy.ipynb new file mode 100644 index 000000000..b4b715cba --- /dev/null +++ b/docs/usage/sympy.ipynb @@ -0,0 +1,232 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "hideCode": true, + "hideOutput": true, + "hidePrompt": true, + "jupyter": { + "source_hidden": true + }, + "slideshow": { + "slide_type": "skip" + }, + "tags": [ + "remove-cell", + "skip-execution" + ] + }, + "outputs": [], + "source": [ + "# WARNING: advised to install a specific version, e.g. ampform==0.1.2\n", + "%pip install -q ampform[doc,viz] IPython" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "hideCode": true, + "hideOutput": true, + "hidePrompt": true, + "jupyter": { + "source_hidden": true + }, + "slideshow": { + "slide_type": "skip" + }, + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "STATIC_WEB_PAGE = {\"EXECUTE_NB\", \"READTHEDOCS\"}.intersection(os.environ)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```{autolink-concat}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SymPy helper functions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The {mod}`ampform.sympy` module contains a few classes that make it easier to construct larger expressions that consist of several mathematical definitions." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Unevaluated expressions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The {func}`.unevaluated_expression` decorator makes it easier to write classes that represent a mathematical function definition. It makes a class that derives from {class}`sp.Expr ` behave more like a {func}`~.dataclasses.dataclass` (see [PEP 861](https://peps.python.org/pep-0681)). All you have to do is:\n", + "\n", + "1. Specify the arguments the function requires.\n", + "2. Specify how to render the 'unevaluated' or 'folded' form of the expression with a `_latex_repr_` string or method.\n", + "3. Specify how to unfold the expression using an `evaluate()` method.\n", + "\n", + "In the example below, we define a phase space factor $\\rho^\\text{CM}$ using the Chew-Mandelstam function (see PDG Resonances section, [Eq. (50.44)](https://pdg.lbl.gov/2023/reviews/rpp2023-rev-resonances.pdf#page=15)). For this, you need to define a break-up momentum $q$ as well." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sympy as sp\n", + "\n", + "from ampform.sympy import unevaluated_expression\n", + "\n", + "\n", + "@unevaluated_expression(real=False)\n", + "class PhspFactorSWave(sp.Expr):\n", + " s: sp.Symbol\n", + " m1: sp.Symbol\n", + " m2: sp.Symbol\n", + " _latex_repr_ = R\"\\rho^\\text{{CM}}\\left({s}\\right)\"\n", + "\n", + " def evaluate(self) -> sp.Expr:\n", + " s, m1, m2 = self.args\n", + " q = BreakupMomentum(s, m1, m2)\n", + " cm = (\n", + " (2 * q / sp.sqrt(s))\n", + " * sp.log((m1**2 + m2**2 - s + 2 * sp.sqrt(s) * q) / (2 * m1 * m2))\n", + " - (m1**2 - m2**2) * (1 / s - 1 / (m1 + m2) ** 2) * sp.log(m1 / m2)\n", + " ) / (16 * sp.pi**2)\n", + " return 16 * sp.pi * sp.I * cm\n", + "\n", + "\n", + "@unevaluated_expression(real=False)\n", + "class BreakupMomentum(sp.Expr):\n", + " s: sp.Symbol\n", + " m1: sp.Symbol\n", + " m2: sp.Symbol\n", + " _latex_repr_ = R\"q\\left({s}\\right)\"\n", + "\n", + " def evaluate(self) -> sp.Expr:\n", + " s, m1, m2 = self.args\n", + " return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2) / (s * 4))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As can be seen, the LaTeX rendering of these classes makes them ideal for mathematically defining and building up larger amplitude models:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "from IPython.display import Math\n", + "\n", + "from ampform.io import aslatex\n", + "\n", + "s, m1, m2 = sp.symbols(\"s m1 m2\")\n", + "q_expr = BreakupMomentum(s, m1, m2)\n", + "rho_expr = PhspFactorSWave(s, m1, m2)\n", + "Math(aslatex({e: e.evaluate() for e in [rho_expr, q_expr]}))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The {class}`.PoolSum` class makes it possible to write sums over non-integer ranges. This is for instance useful when summing over allowed helicities. Here are some examples:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ampform.sympy import PoolSum\n", + "\n", + "i, j, m, n = sp.symbols(\"i j m n\")\n", + "expr = PoolSum(i**m + j**n, (i, (-1, 0, +1)), (j, (2, 4, 5)))\n", + "Math(aslatex({expr: expr.doit()}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "A = sp.IndexedBase(\"A\")\n", + "λ, μ = sp.symbols(\"lambda mu\")\n", + "to_range = lambda a, b: tuple(sp.Rational(i) for i in np.arange(a, b + 0.5))\n", + "expr = abs(PoolSum(A[λ, μ], (λ, to_range(-0.5, +0.5)), (μ, to_range(-1, +1)))) ** 2\n", + "Math(aslatex({expr: expr.doit()}))" + ] + } + ], + "metadata": { + "colab": { + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/environment.yml b/environment.yml index e12b86d6b..fa1316535 100644 --- a/environment.yml +++ b/environment.yml @@ -2,11 +2,11 @@ name: ampform channels: - defaults dependencies: - - python==3.8.* + - python==3.11.* - pip - graphviz # for binder - pip: - - -c .constraints/py3.8.txt -e .[dev] + - -c .constraints/py3.11.txt -e .[dev] variables: PRETTIER_LEGACY_CLI: "1" PYTHONHASHSEED: 0 diff --git a/pyproject.toml b/pyproject.toml index 19978b804..9abb45f87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,6 +176,7 @@ branch = true source = ["src"] [tool.mypy] +enable_incomplete_feature = "Unpack" exclude = "_build" show_error_codes = true warn_unused_configs = true @@ -207,6 +208,7 @@ exclude = [ "**/.tox", "**/__pycache__", "**/_build", + "**/typings", ] reportGeneralTypeIssues = false reportIncompatibleMethodOverride = false @@ -224,6 +226,7 @@ reportUnknownVariableType = false reportUnnecessaryComparison = false reportUnnecessaryContains = false reportUnnecessaryIsInstance = false +reportUntypedClassDecorator = false reportUntypedFunctionDecorator = false reportUnusedClass = true reportUnusedFunction = true @@ -324,6 +327,7 @@ task-tags = ["cspell"] known-third-party = ["sympy"] [tool.ruff.per-file-ignores] +"**/docs/usage/sympy.ipynb" = ["E731"] "*.ipynb" = [ "B018", "C408", diff --git a/src/ampform/kinematics/phasespace.py b/src/ampform/kinematics/phasespace.py index f82567ae5..8796a8812 100644 --- a/src/ampform/kinematics/phasespace.py +++ b/src/ampform/kinematics/phasespace.py @@ -5,23 +5,25 @@ from __future__ import annotations +from typing import Any + import sympy as sp -from ampform.sympy import ( - UnevaluatedExpression, - create_expression, - implement_doit_method, - make_commutative, -) +from ampform.sympy import unevaluated_expression -@make_commutative -@implement_doit_method -class Kibble(UnevaluatedExpression): +@unevaluated_expression +class Kibble(sp.Expr): """Kibble function for determining the phase space region.""" - def __new__(cls, sigma1, sigma2, sigma3, m0, m1, m2, m3, **hints) -> Kibble: - return create_expression(cls, sigma1, sigma2, sigma3, m0, m1, m2, m3, **hints) + sigma1: Any + sigma2: Any + sigma3: Any + m0: Any + m1: Any + m2: Any + m3: Any + _latex_repr_ = R"\phi\left({sigma1}, {sigma2}\right)" def evaluate(self) -> Kallen: sigma1, sigma2, sigma3, m0, m1, m2, m3 = self.args @@ -31,27 +33,20 @@ def evaluate(self) -> Kallen: Kallen(sigma1, m1**2, m0**2), # type: ignore[operator] ) - def _latex(self, printer, *args): - sigma1, sigma2, *_ = map(printer._print, self.args) - return Rf"\phi\left({sigma1}, {sigma2}\right)" - -@make_commutative -@implement_doit_method -class Kallen(UnevaluatedExpression): +@unevaluated_expression +class Kallen(sp.Expr): """Källén function, used for computing break-up momenta.""" - def __new__(cls, x, y, z, **hints) -> Kallen: - return create_expression(cls, x, y, z, **hints) + x: Any + y: Any + z: Any + _latex_repr_ = R"\lambda\left({x}, {y}, {z}\right)" def evaluate(self) -> sp.Expr: x, y, z = self.args return x**2 + y**2 + z**2 - 2 * x * y - 2 * y * z - 2 * z * x # type: ignore[operator] - def _latex(self, printer, *args): - x, y, z = map(printer._print, self.args) - return Rf"\lambda\left({x}, {y}, {z}\right)" - def is_within_phasespace( sigma1, sigma2, m0, m1, m2, m3, outside_value=sp.nan diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index 8152b69b9..74577f604 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -1,4 +1,12 @@ -"""Tools that facilitate in building :mod:`sympy` expressions.""" +"""Tools that facilitate in building :mod:`sympy` expressions. + +.. autodecorator:: unevaluated_expression +.. dropdown:: SymPy assumptions + + .. autodata:: ExprClass + .. autoclass:: SymPyAssumptions + +""" # cspell:ignore mhash from __future__ import annotations @@ -17,6 +25,12 @@ import sympy as sp from sympy.printing.precedence import PRECEDENCE +from ._decorator import ( + ExprClass, # noqa: F401 # pyright: ignore[reportUnusedImport] + SymPyAssumptions, # noqa: F401 # pyright: ignore[reportUnusedImport] + unevaluated_expression, # noqa: F401 # pyright: ignore[reportUnusedImport] +) + if TYPE_CHECKING: from sympy.printing.latex import LatexPrinter from sympy.printing.numpy import NumPyPrinter diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py new file mode 100644 index 000000000..e1f68400b --- /dev/null +++ b/src/ampform/sympy/_decorator.py @@ -0,0 +1,275 @@ +from __future__ import annotations + +import functools +import inspect +import sys +from typing import TYPE_CHECKING, Callable, Iterable, TypeVar, overload + +import sympy as sp + +if sys.version_info < (3, 8): + from typing_extensions import Protocol, TypedDict +else: + from typing import Protocol, TypedDict + +if sys.version_info < (3, 11): + from typing_extensions import ParamSpec, Unpack, dataclass_transform +else: + from typing import ParamSpec, Unpack, dataclass_transform + +if TYPE_CHECKING: + from sympy.printing.latex import LatexPrinter + +ExprClass = TypeVar("ExprClass", bound=sp.Expr) +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +class SymPyAssumptions(TypedDict, total=False): + """See https://docs.sympy.org/latest/guides/assumptions.html#predicates.""" + + algebraic: bool + commutative: bool + complex: bool + extended_negative: bool + extended_nonnegative: bool + extended_nonpositive: bool + extended_nonzero: bool + extended_positive: bool + extended_real: bool + finite: bool + hermitian: bool + imaginary: bool + infinite: bool + integer: bool + irrational: bool + negative: bool + noninteger: bool + nonnegative: bool + nonpositive: bool + nonzero: bool + positive: bool + rational: bool + real: bool + transcendental: bool + zero: bool + + +@overload +def unevaluated_expression(cls: type[ExprClass]) -> type[ExprClass]: ... +@overload +def unevaluated_expression( + *, + implement_doit: bool = True, + **assumptions: Unpack[SymPyAssumptions], +) -> Callable[[type[ExprClass]], type[ExprClass]]: ... + + +@dataclass_transform() # type: ignore[misc] +def unevaluated_expression( # type: ignore[misc] + cls: type[ExprClass] | None = None, *, implement_doit=True, **assumptions +): + r"""Decorator for defining 'unevaluated' SymPy expressions. + + Unevaluated expressions are handy for defining large expressions that consist of + several sub-definitions. + + >>> @unevaluated_expression + ... class MyExpr(sp.Expr): + ... x: sp.Symbol + ... y: sp.Symbol + ... _latex_repr_ = R"z\left({x}, {y}\right)" + ... + ... def evaluate(self) -> sp.Expr: + ... x, y = self.args + ... return x**2 + y**2 + ... + >>> a, b = sp.symbols("a b") + >>> expr = MyExpr(a, b**2) + >>> sp.latex(expr) + 'z\\left(a, b^{2}\\right)' + >>> expr.doit() + a**2 + b**4 + """ + if assumptions is None: + assumptions = {} + if not assumptions.get("commutative"): + assumptions["commutative"] = True + + def decorator(cls: type[ExprClass]) -> type[ExprClass]: + cls = _implement_new_method(cls) + if implement_doit: + cls = _implement_doit(cls) + if hasattr(cls, "_latex_repr_"): + cls = _implement_latex_repr(cls) + _set_assumptions(**assumptions)(cls) + return cls + + if cls is None: + return decorator + return decorator(cls) + + +@dataclass_transform() +def _implement_new_method(cls: type[ExprClass]) -> type[ExprClass]: + """Implement the :meth:`__new__` method for dataclass-like SymPy expression classes. + + >>> @_implement_new_method + ... class MyExpr(sp.Expr): + ... a: sp.Symbol + ... b: sp.Symbol + ... + >>> x, y = sp.symbols("x y") + >>> expr = MyExpr(x**2, y**2) + >>> expr.a + x**2 + >>> expr.args + (x**2, y**2) + >>> sp.sqrt(expr) + sqrt(MyExpr(x**2, y**2)) + """ + attr_names = _get_attribute_names(cls) + + @functools.wraps(cls.__new__) + @_insert_args_in_signature(attr_names, idx=1) + def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: + attr_values, kwargs = _get_attribute_values(attr_names, *args, **kwargs) + attr_values = sp.sympify(attr_values) + expr = sp.Expr.__new__(cls, *attr_values, **kwargs) + for name, value in zip(attr_names, args): + setattr(expr, name, value) + if evaluate: + return expr.evaluate() + return expr + + cls.__new__ = new_method # type: ignore[method-assign] + return cls + + +def _get_attribute_values(attr_names: list, *args, **kwargs) -> tuple[tuple, dict]: + if len(args) == len(attr_names): + return args, kwargs + if len(args) > len(attr_names): + msg = ( + f"Expecting {len(attr_names)} positional arguments" + f" ({', '.join(attr_names)}), but got {len(args)}" + ) + raise ValueError(msg) + attr_values = list(args) + remaining_attr_names = attr_names[len(args) :] + for name in list(remaining_attr_names): + if name in kwargs: + attr_values.append(kwargs.pop(name)) + remaining_attr_names.pop(0) + if remaining_attr_names: + msg = f"Missing constructor arguments: {', '.join(remaining_attr_names)}" + raise ValueError(msg) + return tuple(attr_values), kwargs + + +class LatexMethod(Protocol): + def __call__(self, printer: LatexPrinter, *args) -> str: ... + + +@dataclass_transform() +def _implement_latex_repr(cls: type[_T]) -> type[_T]: + _latex_repr_: LatexMethod | str | None = getattr(cls, "_latex_repr_", None) + if _latex_repr_ is None: + msg = ( + "You need to define a _latex_repr_ str or method in order to decorate an" + " unevaluated expression with a printer method for LaTeX representation." + ) + raise NotImplementedError(msg) + if callable(_latex_repr_): + cls._latex = _latex_repr_ # type: ignore[attr-defined] + else: + attr_names = _get_attribute_names(cls) + + def latex_method(self, printer: LatexPrinter, *args) -> str: + format_kwargs = { + name: printer._print(getattr(self, name), *args) for name in attr_names + } + return _latex_repr_.format(**format_kwargs) # type: ignore[union-attr] + + cls._latex = latex_method # type: ignore[attr-defined] + return cls + + +@dataclass_transform() +def _implement_doit(cls: type[ExprClass]) -> type[ExprClass]: + _check_has_implementation(cls) + + @functools.wraps(cls.doit) + def doit_method(self, deep: bool = True) -> sp.Expr: + expr = self.evaluate() + if deep: + return expr.doit() + return expr + + cls.doit = doit_method # type: ignore[assignment] + return cls + + +def _check_has_implementation(cls: type) -> None: + implementation_method = getattr(cls, "evaluate", None) + if implementation_method is None: + msg = "Decorated class must have an evaluate() method" + raise ValueError(msg) + if not callable(implementation_method): + msg = "evaluate() must be a callable method" + raise TypeError(msg) + + +def _insert_args_in_signature( + new_params: Iterable[str] | None = None, idx: int = 0 +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + if new_params is None: + new_params = [] + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + original_signature = inspect.signature(func) + original_pars = list(original_signature.parameters.values()) + new_parameters = [ + inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD) + for name in new_params + ] + new_parameters = [*original_pars[:idx], *new_parameters, *original_pars[idx:]] + wrapper.__signature__ = inspect.Signature( + parameters=new_parameters, + return_annotation=original_signature.return_annotation, + ) + return wrapper + + return decorator + + +def _get_attribute_names(cls: type) -> list[str]: + """Get the public attributes of a class with dataclass-like semantics. + + >>> class MyClass: + ... a: int + ... b: int + ... _c: int + ... + ... def print(self): ... + ... + >>> _get_attribute_names(MyClass) + ['a', 'b'] + """ + return [v for v in cls.__annotations__ if not callable(v) if not v.startswith("_")] + + +@dataclass_transform() +def _set_assumptions( + **assumptions: Unpack[SymPyAssumptions], +) -> Callable[[type[_T]], type[_T]]: + def class_wrapper(cls: _T) -> _T: + for assumption, value in assumptions.items(): + setattr(cls, f"is_{assumption}", value) + return cls + + return class_wrapper diff --git a/tests/symplot/test_symplot.py b/tests/symplot/test_symplot.py index 395da8473..3d4cf0dee 100644 --- a/tests/symplot/test_symplot.py +++ b/tests/symplot/test_symplot.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import os from copy import deepcopy from typing import Any, Callable, Pattern, no_type_check @@ -14,7 +13,6 @@ from symplot import RangeDefinition, Slider, SliderKwargs -@pytest.mark.skipif("GITHUB_ACTION" in os.environ, reason="ipywidgets instable") class TestSliderKwargs: @pytest.fixture() def slider_kwargs(self) -> SliderKwargs: @@ -89,7 +87,7 @@ def test_repr( assert slider.max == slider_from_repr.max assert slider.value == slider_from_repr.value - @pytest.mark.skipif("CI" not in os.environ, reason="Only works on GitHub Actions") + @pytest.mark.skip(reason="ipywidgets cannot be compied anymore") @pytest.mark.parametrize( ("slider_name", "min_", "max_", "n_steps", "step_size"), [ @@ -136,7 +134,7 @@ def test_set_ranges_exceptions(self, slider_kwargs: SliderKwargs) -> None: with pytest.raises(ValueError, match=r"Number of steps has to be positive"): slider_kwargs.set_ranges({"n": (0, 10, -1)}) - @pytest.mark.skipif("CI" not in os.environ, reason="Only works on GitHub Actions") + @pytest.mark.skip(reason="ipywidgets cannot be compied anymore") def test_set_values( self, slider_kwargs: SliderKwargs, caplog: pytest.LogCaptureFixture ) -> None: diff --git a/tests/sympy/test_decorator.py b/tests/sympy/test_decorator.py new file mode 100644 index 000000000..d1228374a --- /dev/null +++ b/tests/sympy/test_decorator.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import inspect +from typing import Any + +import pytest +import sympy as sp + +from ampform.sympy._decorator import ( + _check_has_implementation, + _implement_latex_repr, + _implement_new_method, + _insert_args_in_signature, + _set_assumptions, + unevaluated_expression, +) + + +def test_check_implementation(): + with pytest.raises(ValueError, match=r"must have an evaluate\(\) method"): + + @_check_has_implementation + class MyExpr1: # pyright: ignore[reportUnusedClass] + pass + + with pytest.raises(TypeError, match=r"evaluate\(\) must be a callable method"): + + @_check_has_implementation + class MyExpr2: # pyright: ignore[reportUnusedClass] + evaluate = "test" + + +def test_implement_latex_repr(): + @_implement_latex_repr + @_implement_new_method + class MyExpr(sp.Expr): + a: sp.Symbol + b: sp.Symbol + _latex_repr_ = R"f\left({a}, {b}\right)" + + alpha, phi = sp.symbols("alpha phi") + expr = MyExpr(alpha, sp.cos(phi)) + assert sp.latex(expr) == R"f\left(\alpha, \cos{\left(\phi \right)}\right)" + + +def test_implement_new_method(): + @_implement_new_method + class MyExpr(sp.Expr): + a: int + b: int + c: int + + with pytest.raises( + ValueError, match=r"^Expecting 3 positional arguments \(a, b, c\), but got 4$" + ): + MyExpr(1, 2, 3, 4) # type: ignore[call-arg] + with pytest.raises(ValueError, match=r"^Missing constructor arguments: c$"): + MyExpr(1, 2) # type: ignore[call-arg] + expr = MyExpr(1, 2, 3) + assert expr.args == (1, 2, 3) + expr = MyExpr(1, b=2, c=3) + assert expr.args == (1, 2, 3) + + +def test_insert_args_in_signature(): + parameters = ["a", "b"] + + @_insert_args_in_signature(parameters) + def my_func(*args, **kwargs) -> int: + return 1 + + signature = inspect.signature(my_func) + assert list(signature.parameters) == [*parameters, "args", "kwargs"] + assert signature.return_annotation == "int" + + +def test_unevaluated_expression(): + @unevaluated_expression + class BreakupMomentum(sp.Expr): + r"""Breakup momentum of a two-body decay :math:`a \to 1+2`.""" + + s: sp.Basic + m1: sp.Basic + m2: sp.Basic + _latex_repr_ = R"q\left({s}\right)" + + def evaluate(self) -> sp.Expr: + s, m1, m2 = self.args + return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2)) # type: ignore[operator] + + m0, ma, mb = sp.symbols("m0 m_a m_b") + expr = BreakupMomentum(m0**2, ma, mb) + args_str = list(inspect.signature(expr.__new__).parameters) + assert args_str == ["s", "m1", "m2", "args", "evaluate", "kwargs"] + latex = sp.latex(expr) + assert latex == R"q\left(m_{0}^{2}\right)" + + +def test_unevaluated_expression_callable(): + @unevaluated_expression(implement_doit=False) + class Squared(sp.Expr): + x: Any + + def evaluate(self) -> sp.Expr: + return self.x**2 + + sqrt = Squared(2) + assert str(sqrt) == "Squared(2)" + assert str(sqrt.doit()) == "Squared(2)" + + @unevaluated_expression(complex=True, implement_doit=False) + class MySqrt(sp.Expr): + x: Any + + expr = MySqrt(-1) + assert expr.is_commutative + assert expr.is_complex # type: ignore[attr-defined] + + +def test_set_assumptions(): + @_set_assumptions(commutative=True, negative=False, real=True) + class MySqrt: ... + + expr = MySqrt() + assert expr.is_commutative # type: ignore[attr-defined] + assert not expr.is_negative # type: ignore[attr-defined] + assert expr.is_real # type: ignore[attr-defined]