From e3413278f79024919cda3082b459b1161b5e2efb Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Thu, 21 Dec 2023 21:16:45 +0100 Subject: [PATCH] ENH: mark `@unevaluated` arguments as non-sympy (#380) * ENH: implement `unevaluated` with `dataclasses` * ENH: implement `__slots__` for non-sympy arguments * ENH: implement `pickle` support * ENH: do not overwrite `_eval_subs` if no non-sympy attributes --- docs/conf.py | 1 + pyproject.toml | 5 + src/ampform/sympy/__init__.py | 3 + src/ampform/sympy/_decorator.py | 270 +++++++++++++++------- tests/sympy/decorator/test_helpers.py | 18 ++ tests/sympy/decorator/test_unevaluated.py | 10 +- 6 files changed, 217 insertions(+), 90 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 5550a1038..82270a326 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -49,6 +49,7 @@ add_module_names = False api_github_repo = f"{ORGANIZATION}/{REPO_NAME}" api_target_substitutions: dict[str, str | tuple[str, str]] = { + "T": "TypeVar", "BuilderReturnType": ("obj", "ampform.dynamics.builder.BuilderReturnType"), "DecoratedClass": ("obj", "ampform.sympy.DecoratedClass"), "DecoratedExpr": ("obj", "ampform.sympy.DecoratedExpr"), diff --git a/pyproject.toml b/pyproject.toml index 0c7eb4d9b..9923f8665 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -328,6 +328,11 @@ task-tags = ["cspell"] [tool.ruff.isort] known-third-party = ["sympy"] +[tool.ruff.lint.flake8-bugbear] +extend-immutable-calls = [ + "ampform.sympy._decorator.argument", +] + [tool.ruff.per-file-ignores] "**/docs/usage/sympy.ipynb" = ["E731"] "*.ipynb" = [ diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index 581448068..15caa32a9 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -1,6 +1,8 @@ """Tools that facilitate in building :mod:`sympy` expressions. .. autodecorator:: unevaluated +.. autofunction:: argument + .. dropdown:: SymPy assumptions .. autodata:: ExprClass @@ -30,6 +32,7 @@ from ._decorator import ( ExprClass, # noqa: F401 # pyright: ignore[reportUnusedImport] SymPyAssumptions, # noqa: F401 # pyright: ignore[reportUnusedImport] + argument, # noqa: F401 # pyright: ignore[reportUnusedImport] unevaluated, # noqa: F401 # pyright: ignore[reportUnusedImport] ) diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index 138ec3698..881940a35 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -1,14 +1,20 @@ from __future__ import annotations +import dataclasses import functools import inspect import sys from collections import abc +from dataclasses import MISSING, Field +from dataclasses import astuple as _get_arguments +from dataclasses import dataclass as _create_dataclass +from dataclasses import field as _create_field +from dataclasses import fields as _get_fields from inspect import isclass +from types import MappingProxyType from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterable, TypeVar, overload import sympy as sp -from attrs import frozen from sympy.core.basic import _aresame from sympy.utilities.exceptions import SymPyDeprecationWarning @@ -67,6 +73,30 @@ class SymPyAssumptions(TypedDict, total=False): zero: bool +@overload +def argument(*, default: T = MISSING, sympify: bool = True) -> T: ... # type: ignore[assignment] +@overload +def argument( + *, default_factory: Callable[[], T] = MISSING, sympify: bool = True # type: ignore[assignment] +) -> T: ... +def argument( + *, + default=MISSING, + default_factory=MISSING, + sympify=True, +): + """Add qualifiers to fields of `unevaluated` SymPy expression classes. + + Creates a :class:`dataclasses.Field` with additional metadata for + :func:`unevaluated` by wrapping around :func:`dataclasses.field`. + """ + return _create_field( + default=default, + default_factory=default_factory, + metadata={"sympify": sympify}, + ) + + @overload def unevaluated(cls: type[ExprClass]) -> type[ExprClass]: ... @overload @@ -75,14 +105,15 @@ def unevaluated( implement_doit: bool = True, **assumptions: Unpack[SymPyAssumptions], ) -> Callable[[type[ExprClass]], type[ExprClass]]: ... -@dataclass_transform() +@dataclass_transform(field_specifiers=(argument, _create_field)) def unevaluated( cls: type[ExprClass] | None = None, *, implement_doit=True, **assumptions ): r"""Decorator for defining 'unevaluated' SymPy expressions. Unevaluated expressions are handy for defining large expressions that consist of - several sub-definitions. + several sub-definitions. They are 'unfolded' to their definition once you call their + :meth`~sympy.core.expr.Expr.doit` method. For example: >>> @unevaluated ... class MyExpr(sp.Expr): @@ -100,6 +131,52 @@ def unevaluated( 'z\\left(a, b^{2}\\right)' >>> expr.doit() a**2 + b**4 + + A LaTeX representation for the unevaluated state can be provided by providing an + `f-string `_ or + method called :code:`_latex_repr_`: + + >>> @unevaluated + ... class Function(sp.Expr): + ... x: sp.Symbol + ... _latex_repr_ = R"f\left({x}\right)" + ... + ... def evaluate(self) -> sp.Expr: + ... return sp.sqrt(self.x) + ... + >>> y = sp.Symbol("y", nonnegative=True) + >>> expr = Function(x=y**2) + >>> sp.latex(expr) + 'f\\left(y^{2}\\right)' + >>> expr.doit() + y + + Attributes to the class are fed to the `~object.__new__` constructor of the + :class`~sympy.core.expr.Expr` class and are therefore also called "arguments". Just + like in the :class:`~sympy.core.expr.Expr` class, these arguments are automatically + `sympified + `_. + Attributes/arguments that should not be sympified with :func:`argument`: + + >>> class Transformation: + ... def __call__(self, x: sp.Basic, y: sp.Basic) -> sp.Expr: ... + ... + >>> @unevaluated + ... class MyExpr(sp.Expr): + ... x: Any + ... y: Any + ... functor: Callable = argument(sympify=False) + ... + ... def evaluate(self) -> sp.Expr: + ... return self.functor(self.x, self.y) + ... + >>> expr = MyExpr(0, y=3.14, functor=Transformation) + >>> isinstance(expr.x, sp.Integer) + True + >>> isinstance(expr.y, sp.Float) + True + >>> expr.functor is Transformation + True """ if assumptions is None: assumptions = {} @@ -120,9 +197,9 @@ def decorator(cls: type[ExprClass]) -> type[ExprClass]: return decorator(cls) -@dataclass_transform() +@dataclass_transform(field_specifiers=(argument, _create_field)) def _implement_new_method(cls: type[ExprClass]) -> type[ExprClass]: - """Implement the :meth:`__new__` method for dataclass-like SymPy expression classes. + """Implement :meth:`~object.__new__` for dataclass-like SymPy expression classes. >>> @_implement_new_method ... class MyExpr(sp.Expr): @@ -138,26 +215,55 @@ def _implement_new_method(cls: type[ExprClass]) -> type[ExprClass]: >>> sp.sqrt(expr) sqrt(MyExpr(x**2, y**2)) """ - attr_names = _get_attribute_names(cls) + cls = _create_dataclass( + init=False, # __new__ method through sp.Expr + repr=False, + eq=False, + order=False, + unsafe_hash=False, + frozen=False, + )(cls) + cls = _update_field_metadata(cls) + sympy_fields = _get_sympy_fields(cls) + non_sympy_fields = tuple(f for f in _get_fields(cls) if not _is_sympify(f)) # type: ignore[arg-type] + cls.__slots__ = tuple(f.name for f in non_sympy_fields) # type: ignore[arg-type] @functools.wraps(cls.__new__) - @_insert_args_in_signature(attr_names, idx=1) + @_insert_args_in_signature([f.name for f in sympy_fields], idx=1) def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: - attr_values, hints = _get_attribute_values(cls, attr_names, *args, **kwargs) - converted_attr_values = _safe_sympify(*attr_values) - expr = sp.Expr.__new__(cls, *converted_attr_values.sympy, **hints) - for name, value in zip(attr_names, converted_attr_values.all_args): - setattr(expr, name, value) - expr._all_args = converted_attr_values.all_args - expr._non_sympy_args = converted_attr_values.non_sympy + fields_with_values, hints = _extract_field_values(cls, *args, **kwargs) + fields_with_sympified_values = { + field: _safe_sympify(field, value) + for field, value in fields_with_values.items() + } + sympy_args = tuple( + value + for field, value in fields_with_sympified_values.items() + if _is_sympify(field) + ) + expr = sp.Expr.__new__(cls, *sympy_args, **hints) + for field, value in fields_with_sympified_values.items(): + setattr(expr, field.name, value) if evaluate: return expr.evaluate() return expr cls.__new__ = new_method # type: ignore[method-assign] - cls._eval_subs = _eval_subs_method # type: ignore[method-assign] + cls.__getnewargs__ = _get_arguments # type: ignore[assignment,method-assign] cls._hashable_content = _hashable_content_method # type: ignore[method-assign] - cls._xreplace = _xreplace_method # type: ignore[method-assign] + if non_sympy_fields: + cls._eval_subs = _eval_subs_method # type: ignore[method-assign] + cls._xreplace = _xreplace_method # type: ignore[method-assign] + return cls + + +def _update_field_metadata(cls: T) -> T: + """Set the :code:`sympify` metadata for all fields of a dataclass-like class.""" + for field in _get_fields(cls): # type: ignore[arg-type] + new_metadata = dict(field.metadata) + if "sympify" not in new_metadata: + new_metadata["sympify"] = True + field.metadata = MappingProxyType(new_metadata) return cls @@ -177,83 +283,64 @@ def _get_hashable_object(obj): return obj -def _get_attribute_values( - cls: type[ExprClass], attr_names: tuple[str, ...], *args, **kwargs -) -> tuple[tuple, dict[str, Any]]: +def _extract_field_values( + cls: type, *args, **kwargs +) -> tuple[dict[Field, Any], dict[str, Any]]: """Extract the attribute values from the constructor arguments. Returns a `tuple` of: - 1. the extracted, ordered attributes as requested by :code:`attr_names`, + 1. the values for the dataclass fields extracted from :code:`*args` and + :code:`**kwargs`, 2. a `dict` of remaining keyword arguments that can be used hints for the constructed :class:`sp.Expr` instance. An attempt is made to get any missing attributes from the type hints in the class definition. """ - if len(args) == len(attr_names): - return args, kwargs - if len(args) > len(attr_names): + fields = _get_fields(cls) + if len(args) == len(fields): + return dict(zip(fields, args)), kwargs + if len(args) > len(fields): msg = ( - f"Expecting {len(attr_names)} positional arguments" - f" ({', '.join(attr_names)}), but got {len(args)}" + f"Expecting {len(fields)} positional arguments" + f" ({', '.join(f.name for f in fields)}), but got {len(args)}" ) raise ValueError(msg) - attr_values = list(args) - remaining_attr_names = list(attr_names[len(args) :]) - for name in list(remaining_attr_names): - if name in kwargs: - attr_values.append(kwargs.pop(name)) - remaining_attr_names.pop(0) - elif hasattr(cls, name): - default_value = getattr(cls, name) - attr_values.append(default_value) - remaining_attr_names.pop(0) - if remaining_attr_names: - msg = f"Missing constructor arguments: {', '.join(remaining_attr_names)}" - raise ValueError(msg) - return tuple(attr_values), kwargs - - -def _safe_sympify(*args: Any) -> _ExprNewArumgents: - all_args = [] - sympy_args = [] - non_sympy_args = [] - for arg in args: - converted_arg, is_sympy = _try_sympify(arg) - if is_sympy: - sympy_args.append(converted_arg) + fields_with_values = dict(zip(fields, args)) + remaining_attrs = fields[len(args) :] + missing: list[str] = [] + for field in remaining_attrs: + if field.name in kwargs: + fields_with_values[field] = kwargs.pop(field.name) + elif field.default is MISSING: + missing.append(field.name) else: - non_sympy_args.append(converted_arg) - all_args.append(converted_arg) - return _ExprNewArumgents( - all_args=tuple(all_args), - sympy=tuple(sympy_args), - non_sympy=tuple(non_sympy_args), - ) - - -def _try_sympify(obj) -> tuple[Any, bool]: - if isinstance(obj, str): - return obj, False - try: - return sp.sympify(obj), True - except (TypeError, SymPyDeprecationWarning, sp.SympifyError): - return obj, False + fields_with_values[field] = field.default + if missing: + msg = f"Missing constructor arguments: {', '.join(missing)}" + raise ValueError(msg) + return fields_with_values, kwargs -@frozen -class _ExprNewArumgents: - all_args: tuple[Any, ...] - sympy: tuple[sp.Basic, ...] - non_sympy: tuple[Any, ...] +def _safe_sympify(field: Field, value: dict[Field, Any]) -> dict[Field, Any]: + if _is_sympify(field): + try: + return sp.sympify(value) + except (sp.SympifyError, TypeError, SymPyDeprecationWarning) as exc: + msg = ( + f"Attribute {field.name} could not be sympified. Did you forget to mark" + " it with argument(sympify=False)?" + ) + raise TypeError(msg) from exc + return value class LatexMethod(Protocol): def __call__(self, printer: LatexPrinter, *args) -> str: ... -@dataclass_transform() +@dataclass_transform(field_specifiers=(argument, _create_field)) def _implement_latex_repr(cls: type[T]) -> type[T]: _latex_repr_: LatexMethod | str | None = getattr(cls, "_latex_repr_", None) if _latex_repr_ is None: @@ -277,7 +364,7 @@ def latex_method(self, printer: LatexPrinter, *args) -> str: return cls -@dataclass_transform() +@dataclass_transform(field_specifiers=(argument, _create_field)) def _implement_doit(cls: type[ExprClass]) -> type[ExprClass]: _check_has_implementation(cls) @@ -352,7 +439,7 @@ def _get_attribute_names(cls: type) -> tuple[str, ...]: ) -@dataclass_transform() +@dataclass_transform(field_specifiers=(argument, _create_field)) def _set_assumptions( **assumptions: Unpack[SymPyAssumptions], ) -> Callable[[type[T]], type[T]]: @@ -367,23 +454,24 @@ def class_wrapper(cls: T) -> T: def _eval_subs_method(self, old, new, **hints): # https://github.com/sympy/sympy/blob/1.12/sympy/core/basic.py#L1117-L1147 hit = False - substituted_attrs = list(self._all_args) - for i, old_attr in enumerate(substituted_attrs): - if not hasattr(old_attr, "_eval_subs"): + old_args = _get_arguments(self) + new_args = list(old_args) + for i, old_arg in enumerate(old_args): + if not hasattr(old_arg, "_eval_subs"): continue - if isclass(old_attr): + if isclass(old_arg): continue - new_attr = old_attr._subs(old, new, **hints) - if not _aresame(new_attr, old_attr): + new_attr = old_arg._subs(old, new, **hints) + if not _aresame(new_attr, old_arg): hit = True - substituted_attrs[i] = new_attr + new_args[i] = new_attr if hit: - rv = self.func(*substituted_attrs) + rv = self.func(*new_args) hack2 = hints.get("hack2", False) if hack2 and self.is_Mul and not rv.is_Mul: # 2-arg hack coefficient = sp.S.One nonnumber = [] - for i in substituted_attrs: + for i in new_args: if i.is_Number: coefficient *= i else: @@ -398,9 +486,13 @@ def _eval_subs_method(self, old, new, **hints): def _hashable_content_method(self) -> tuple: hashable_content = super(sp.Expr, self)._hashable_content() - if not self._non_sympy_args: + if not dataclasses.is_dataclass(self): return hashable_content - remaining_content = (_get_hashable_object(arg) for arg in self._non_sympy_args) + remaining_content = ( + _get_hashable_object(getattr(self, field.name)) + for field in _get_fields(self) + if not _is_sympify(field) + ) return (*hashable_content, *remaining_content) @@ -411,7 +503,7 @@ def _xreplace_method(self, rule) -> tuple[sp.Expr, bool]: if rule: new_args = [] hit = False - for arg in self._all_args: + for arg in _get_arguments(self): if hasattr(arg, "_xreplace") and not isclass(arg): replace_result, is_replaced = arg._xreplace(rule) elif isinstance(rule, abc.Mapping): @@ -425,3 +517,11 @@ def _xreplace_method(self, rule) -> tuple[sp.Expr, bool]: if hit: return self.func(*new_args), True return self, False + + +def _get_sympy_fields(cls) -> tuple: + return tuple(f for f in _get_fields(cls) if _is_sympify(f)) + + +def _is_sympify(field: Field) -> bool: + return bool(field.metadata.get("sympify")) diff --git a/tests/sympy/decorator/test_helpers.py b/tests/sympy/decorator/test_helpers.py index 107ecdf47..576fb6f10 100644 --- a/tests/sympy/decorator/test_helpers.py +++ b/tests/sympy/decorator/test_helpers.py @@ -1,6 +1,7 @@ from __future__ import annotations import inspect +from dataclasses import dataclass, field, fields import pytest import sympy as sp @@ -11,6 +12,7 @@ _implement_new_method, _insert_args_in_signature, _set_assumptions, + _update_field_metadata, ) @@ -80,3 +82,19 @@ class MySqrt: ... assert expr.is_commutative # type: ignore[attr-defined] assert not expr.is_negative # type: ignore[attr-defined] assert expr.is_real # type: ignore[attr-defined] + + +def test_update_field_metadata(): + @_update_field_metadata + @dataclass + class MyClass: + a: int + b: float + name: str = field(metadata={"sympify": False}) + + cls_fields = {f.name: f.metadata["sympify"] for f in fields(MyClass)} + assert cls_fields == { + "a": True, + "b": True, + "name": False, + } diff --git a/tests/sympy/decorator/test_unevaluated.py b/tests/sympy/decorator/test_unevaluated.py index ce9bfdba1..91788c40e 100644 --- a/tests/sympy/decorator/test_unevaluated.py +++ b/tests/sympy/decorator/test_unevaluated.py @@ -5,7 +5,7 @@ import sympy as sp -from ampform.sympy._decorator import unevaluated +from ampform.sympy._decorator import argument, unevaluated def test_classvar_behavior(): @@ -42,7 +42,7 @@ class CannotBeSympified: ... @unevaluated(implement_doit=False) class MyExpr(sp.Expr): sympifiable: Any - non_sympy: CannotBeSympified + non_sympy: CannotBeSympified = argument(sympify=False) obj = CannotBeSympified() expr = MyExpr( @@ -114,7 +114,7 @@ class CannotBeSympified: ... @unevaluated(implement_doit=False) class MyExpr(sp.Expr): x: Any - typ: type[CannotBeSympified] + typ: type[CannotBeSympified] = argument(sympify=False) x = sp.Symbol("x") expr = MyExpr(x, typ=CannotBeSympified) @@ -177,7 +177,7 @@ class Protocol: ... @unevaluated(implement_doit=False) class MyExpr(sp.Expr): x: Any - protocol: type[Protocol] = Protocol + protocol: type[Protocol] = argument(default=Protocol, sympify=False) x, y = sp.symbols("x y") expr = MyExpr(x) @@ -197,7 +197,7 @@ class Protocol2(Protocol): ... @unevaluated(implement_doit=False) class MyExpr(sp.Expr): x: Any - protocol: type[Protocol] = Protocol1 + protocol: type[Protocol] = argument(default=Protocol1, sympify=False) x, y = sp.symbols("x y") expr = MyExpr(x)