diff --git a/.cspell.json b/.cspell.json index c7e0ea1f1..871b55b36 100644 --- a/.cspell.json +++ b/.cspell.json @@ -182,6 +182,7 @@ "sharey", "startswith", "suptitle", + "sympified", "sympify", "symplot", "theano", diff --git a/docs/usage/sympy.ipynb b/docs/usage/sympy.ipynb index b4b715cba..f93e42676 100644 --- a/docs/usage/sympy.ipynb +++ b/docs/usage/sympy.ipynb @@ -162,6 +162,56 @@ "Math(aslatex({e: e.evaluate() for e in [rho_expr, q_expr]}))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Class variables and default arguments to instance arguments are also supported:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "\n", + "from typing import Any, ClassVar\n", + "\n", + "\n", + "@unevaluated_expression\n", + "class FunkyPower(sp.Expr):\n", + " x: Any\n", + " m: int = 1\n", + " default_return: ClassVar[sp.Expr | None] = None\n", + " _latex_repr_ = R\"f_{{{m}}}\\left({x}\\right)\"\n", + "\n", + " def evaluate(self) -> sp.Expr | None:\n", + " if self.default_return is None:\n", + " return self.x**self.m\n", + " return self.default_return\n", + "\n", + "\n", + "x = sp.Symbol(\"x\")\n", + "exprs = (\n", + " FunkyPower(x),\n", + " FunkyPower(x, 2),\n", + " FunkyPower(x, m=3),\n", + ")\n", + "Math(aslatex({e: e.doit() for e in exprs}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "FunkyPower.default_return = sp.Rational(0.5)\n", + "Math(aslatex({e: e.doit() for e in exprs}))" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index e8501b84e..a3e42ce2c 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -133,10 +133,10 @@ def _implement_new_method(cls: type[ExprClass]) -> type[ExprClass]: @functools.wraps(cls.__new__) @_insert_args_in_signature(attr_names, idx=1) def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: - attr_values, kwargs = _get_attribute_values(attr_names, *args, **kwargs) - attr_values = sp.sympify(attr_values) - expr = sp.Expr.__new__(cls, *attr_values, **kwargs) - for name, value in zip(attr_names, attr_values): + positional_args, hints = _get_attribute_values(cls, attr_names, *args, **kwargs) + sympified_args = sp.sympify(positional_args) + expr = sp.Expr.__new__(cls, *sympified_args, **hints) + for name, value in zip(attr_names, sympified_args): setattr(expr, name, value) if evaluate: return expr.evaluate() @@ -147,7 +147,7 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: def _get_attribute_values( - attr_names: tuple[str, ...], *args, **kwargs + cls: type[ExprClass], attr_names: tuple[str, ...], *args, **kwargs ) -> tuple[tuple, dict[str, Any]]: if len(args) == len(attr_names): return args, kwargs @@ -163,6 +163,10 @@ def _get_attribute_values( if name in kwargs: attr_values.append(kwargs.pop(name)) remaining_attr_names.pop(0) + elif hasattr(cls, name): + default_value = getattr(cls, name) + attr_values.append(default_value) + remaining_attr_names.pop(0) if remaining_attr_names: msg = f"Missing constructor arguments: {', '.join(remaining_attr_names)}" raise ValueError(msg) diff --git a/tests/sympy/test_decorator.py b/tests/sympy/test_decorator.py index e446ab5a0..b9e4209a2 100644 --- a/tests/sympy/test_decorator.py +++ b/tests/sympy/test_decorator.py @@ -134,6 +134,43 @@ def evaluate(self) -> sp.Expr: assert y_expr.doit() == 5**3 +def test_unevaluated_expression_default_argument(): + @unevaluated_expression + class FunkyPower(sp.Expr): + x: Any + m: int = 1 + default_return: ClassVar[float | None] = None + + def evaluate(self) -> sp.Expr: + if self.default_return is None: + return self.x**self.m + return sp.sympify(self.default_return) + + x = sp.Symbol("x") + exprs = ( + FunkyPower(x), + FunkyPower(x, 2), + FunkyPower(x, m=3), + ) + assert exprs[0].doit() == x + assert exprs[1].doit() == x**2 + assert exprs[2].doit() == x**3 + for expr in exprs: + assert expr.x is x + assert isinstance(expr.m, sp.Integer) + assert expr.default_return is None + + half = sp.Rational(1, 2) + FunkyPower.default_return = half + assert exprs[0].doit() == half + assert exprs[1].doit() == half + assert exprs[2].doit() == half + for expr in exprs: + assert expr.x is x + assert isinstance(expr.m, sp.Integer) + assert expr.default_return is half + + def test_unevaluated_expression_callable(): @unevaluated_expression(implement_doit=False) class Squared(sp.Expr): @@ -153,3 +190,20 @@ class MySqrt(sp.Expr): expr = MySqrt(-1) assert expr.is_commutative assert expr.is_complex # type: ignore[attr-defined] + + +def test_unevaluated_expression_default_args(): + @unevaluated_expression + class MyExpr(sp.Expr): + x: Any + m: int = 2 + + def evaluate(self) -> sp.Expr: + return self.x**self.m + + expr1 = MyExpr(x=5) + assert str(expr1) == "MyExpr(5, 2)" + assert expr1.doit() == 5**2 + + expr2 = MyExpr(4, 3) + assert expr2.doit() == 4**3