From a072b4e4ef43bfae0b8f9ec3ab4f1fc6803c232f Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 29 Jan 2024 12:16:11 +0100 Subject: [PATCH 1/3] MAINT: write latex test for `InvariantMass` --- tests/kinematics/test_lorentz.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/kinematics/test_lorentz.py b/tests/kinematics/test_lorentz.py index 096ad2ae5..55048c120 100644 --- a/tests/kinematics/test_lorentz.py +++ b/tests/kinematics/test_lorentz.py @@ -193,6 +193,12 @@ def test_latex(self): class TestInvariantMass: + def test_latex(self): + p = FourMomentumSymbol("p1", shape=[]) + mass = InvariantMass(p) + latex = sp.latex(mass) + assert latex == "m_{p_{1}}" + @pytest.mark.parametrize( ("state_id", "expected_mass"), [ From b073463f3029c7099daa68ff279ca3292e7444f8 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 29 Jan 2024 12:16:46 +0100 Subject: [PATCH 2/3] FIX: add back expr class LaTeX rendering --- src/ampform/kinematics/lorentz.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/ampform/kinematics/lorentz.py b/src/ampform/kinematics/lorentz.py index 0b437ac9f..7c6ec79e7 100644 --- a/src/ampform/kinematics/lorentz.py +++ b/src/ampform/kinematics/lorentz.py @@ -135,7 +135,7 @@ class EuclideanNorm(NumPyPrintable): """Take the euclidean norm of an array over axis 1.""" vector: sp.Basic - _latex_repr = R"\left|{vector}\right|" + _latex_repr_ = R"\left|{vector}\right|" def evaluate(self) -> ArraySlice: return sp.sqrt(EuclideanNormSquared(self.vector)) @@ -167,7 +167,7 @@ class InvariantMass(sp.Expr): """Invariant mass of a `.FourMomentumSymbol`.""" momentum: sp.Basic - _latex_repr = "m_{{{momentum}}}" + _latex_repr_ = "m_{{{momentum}}}" def evaluate(self) -> ComplexSqrt: p = self.momentum @@ -180,7 +180,7 @@ class NegativeMomentum(sp.Expr): r"""Invert the spatial components of a `.FourMomentumSymbol`.""" momentum: sp.Basic - _latex_repr = R"-\left({momentum}\right)" + _latex_repr_ = R"-\left({momentum}\right)" def evaluate(self) -> sp.Expr: p = self.momentum @@ -265,7 +265,7 @@ class _BoostZMatrixImplementation(NumPyPrintable): gamma_beta: sp.Basic ones: _OnesArray zeros: _ZerosArray - _latex_repr = R"\boldsymbol{{B_z}}\left({beta}\right)" + _latex_repr_ = R"\boldsymbol{{B_z}}\left({beta}\right)" def _numpycode(self, printer: NumPyPrinter, *args) -> str: printer.module_imports[printer._module].add("array") @@ -285,7 +285,7 @@ class BoostMatrix(sp.Expr): r"""Compute a rank-3 Lorentz boost matrix from a `.FourMomentumSymbol`.""" momentum: sp.Basic - _latex_repr = R"\boldsymbol{{B}}\left({momentum}\right)" + _latex_repr_ = R"\boldsymbol{{B}}\left({momentum}\right)" def as_explicit(self) -> sp.MutableDenseMatrix: momentum = self.momentum @@ -353,7 +353,7 @@ class _BoostMatrixImplementation(NumPyPrintable): b22: sp.Basic b23: sp.Basic b33: sp.Basic - _latex_repr = R"\boldsymbol{{B}}\left({momentum}\right)" + _latex_repr_ = R"\boldsymbol{{B}}\left({momentum}\right)" def _numpycode(self, printer: NumPyPrinter, *args) -> str: _, b00, b01, b02, b03, b11, b12, b13, b22, b23, b33 = self.args @@ -409,7 +409,7 @@ class _RotationYMatrixImplementation(NumPyPrintable): sin_angle: sp.Basic ones: _OnesArray zeros: _ZerosArray - _latex_repr = R"\boldsymbol{{R_y}}\left({angle}\right)" + _latex_repr_ = R"\boldsymbol{{R_y}}\left({angle}\right)" def _numpycode(self, printer: NumPyPrinter, *args) -> str: printer.module_imports[printer._module].add("array") @@ -466,7 +466,7 @@ class _RotationZMatrixImplementation(NumPyPrintable): sin_angle: sp.Basic ones: _OnesArray zeros: _ZerosArray - _latex_repr = R"\boldsymbol{{R_z}}\left({angle}\right)" + _latex_repr_ = R"\boldsymbol{{R_z}}\left({angle}\right)" def _numpycode(self, printer: NumPyPrinter, *args) -> str: printer.module_imports[printer._module].add("array") From 600081f909e626a45c0fe50b88b102663a62b3aa Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 29 Jan 2024 12:20:47 +0100 Subject: [PATCH 3/3] ENH: check for `_latex_repr_` typo --- src/ampform/sympy/_decorator.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index 708416287..3cf0ac2b7 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -344,10 +344,15 @@ def __call__(self, printer: LatexPrinter, *args) -> str: ... @dataclass_transform(field_specifiers=(argument, _create_field)) def _implement_latex_repr(cls: type[T]) -> type[T]: - _latex_repr_: LatexMethod | str | None = getattr(cls, "_latex_repr_", None) + repr_name = "_latex_repr_" + repr_mistyped = "_latex_repr" + if hasattr(cls, repr_mistyped): + msg = f"Class defines a {repr_mistyped} attribute, but it should be {repr_name}" + raise AttributeError(msg) + _latex_repr_: LatexMethod | str | None = getattr(cls, repr_name, None) if _latex_repr_ is None: msg = ( - "You need to define a _latex_repr_ str or method in order to decorate an" + f"You need to define a {repr_name} str or method in order to decorate an" " unevaluated expression with a printer method for LaTeX representation." ) raise NotImplementedError(msg)