From 6dbf59582edae3135452f07d8764af69ed0aca23 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 12 Feb 2024 11:16:29 +0100 Subject: [PATCH] DOC: explain usage of `argument()` function (#396) * ENH: emit warning if `_latex_repr_` is mistyped * FIX: use correct docstring syntax --- docs/usage/sympy.ipynb | 66 +++++++++++++++++++++++ src/ampform/sympy/_decorator.py | 12 +++-- tests/sympy/decorator/test_unevaluated.py | 16 ++++++ 3 files changed, 89 insertions(+), 5 deletions(-) diff --git a/docs/usage/sympy.ipynb b/docs/usage/sympy.ipynb index bc35b5cbf..902676bdf 100644 --- a/docs/usage/sympy.ipynb +++ b/docs/usage/sympy.ipynb @@ -213,6 +213,72 @@ "Math(aslatex({e: e.doit() for e in exprs}))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By default, instance attributes are converted ['sympified'](https://docs.sympy.org/latest/modules/core.html#module-sympy.core.sympify). To avoid this behavior, use the {func}`.argument` function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Callable\n", + "\n", + "from ampform.sympy import argument\n", + "\n", + "\n", + "class Transformation:\n", + " def __init__(self, power: int) -> None:\n", + " self.power = power\n", + "\n", + " def __call__(self, x: sp.Basic, y: sp.Basic) -> sp.Expr:\n", + " return x + y**self.power\n", + "\n", + "\n", + "@unevaluated\n", + "class MyExpr(sp.Expr):\n", + " x: Any\n", + " y: Any\n", + " functor: Callable = argument(sympify=False)\n", + "\n", + " def evaluate(self) -> sp.Expr:\n", + " return self.functor(self.x, self.y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice how the `functor` attribute has not been sympified (there is no SymPy equivalent for a callable object), but the `functor` can be called in the `evaluate()`/`doit()` method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a, b, k = sp.symbols(\"a b k\")\n", + "expr = MyExpr(a, y=b, functor=Transformation(power=k))\n", + "assert expr.x is a\n", + "assert expr.y is b\n", + "assert not isinstance(expr.functor, sp.Basic)\n", + "Math(aslatex({expr: expr.doit()}))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ":::{tip}\n", + "An example where this is used, is in the {class}`.EnergyDependentWidth` class, where we do not want to sympify the {attr}`~.EnergyDependentWidth.phsp_factor` protocol.\n", + ":::" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index 3cf0ac2b7..89ab23d29 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -4,6 +4,7 @@ import functools import inspect import sys +import warnings from collections import abc from dataclasses import MISSING, Field from dataclasses import astuple as _get_arguments @@ -152,7 +153,7 @@ def unevaluated( y Attributes to the class are fed to the `~object.__new__` constructor of the - :class`~sympy.core.expr.Expr` class and are therefore also called "arguments". Just + :class:`~sympy.core.expr.Expr` class and are therefore also called "arguments". Just like in the :class:`~sympy.core.expr.Expr` class, these arguments are automatically `sympified `_. @@ -187,6 +188,11 @@ def decorator(cls: type[ExprClass]) -> type[ExprClass]: cls = _implement_new_method(cls) if implement_doit: cls = _implement_doit(cls) + typos = ["_latex_repr"] + for typo in typos: + if hasattr(cls, typo): + msg = f"Class defines a {typo} attribute, but it should be _latex_repr_" + warnings.warn(msg, category=UserWarning, stacklevel=1) if hasattr(cls, "_latex_repr_"): cls = _implement_latex_repr(cls) _set_assumptions(**assumptions)(cls) @@ -345,10 +351,6 @@ def __call__(self, printer: LatexPrinter, *args) -> str: ... @dataclass_transform(field_specifiers=(argument, _create_field)) def _implement_latex_repr(cls: type[T]) -> type[T]: repr_name = "_latex_repr_" - repr_mistyped = "_latex_repr" - if hasattr(cls, repr_mistyped): - msg = f"Class defines a {repr_mistyped} attribute, but it should be {repr_name}" - raise AttributeError(msg) _latex_repr_: LatexMethod | str | None = getattr(cls, repr_name, None) if _latex_repr_ is None: msg = ( diff --git a/tests/sympy/decorator/test_unevaluated.py b/tests/sympy/decorator/test_unevaluated.py index 4849b8618..163e9e7ea 100644 --- a/tests/sympy/decorator/test_unevaluated.py +++ b/tests/sympy/decorator/test_unevaluated.py @@ -3,6 +3,7 @@ import inspect from typing import Any, ClassVar +import pytest import sympy as sp from ampform.sympy._decorator import argument, unevaluated @@ -124,6 +125,21 @@ class MyExpr(sp.Expr): ) +def test_latex_repr_typo_warning(): + with pytest.warns( + UserWarning, + match=r"Class defines a _latex_repr attribute, but it should be _latex_repr_", + ): + + @unevaluated(real=False) + class MyExpr(sp.Expr): # pyright: ignore[reportUnusedClass] + x: sp.Symbol + _latex_repr = "" + + def evaluate(self) -> sp.Expr: + return self.x + + def test_no_implement_doit(): @unevaluated(implement_doit=False) class Squared(sp.Expr):