diff --git a/src/ampform/kinematics/lorentz.py b/src/ampform/kinematics/lorentz.py index 0b437ac9f..7c6ec79e7 100644 --- a/src/ampform/kinematics/lorentz.py +++ b/src/ampform/kinematics/lorentz.py @@ -135,7 +135,7 @@ class EuclideanNorm(NumPyPrintable): """Take the euclidean norm of an array over axis 1.""" vector: sp.Basic - _latex_repr = R"\left|{vector}\right|" + _latex_repr_ = R"\left|{vector}\right|" def evaluate(self) -> ArraySlice: return sp.sqrt(EuclideanNormSquared(self.vector)) @@ -167,7 +167,7 @@ class InvariantMass(sp.Expr): """Invariant mass of a `.FourMomentumSymbol`.""" momentum: sp.Basic - _latex_repr = "m_{{{momentum}}}" + _latex_repr_ = "m_{{{momentum}}}" def evaluate(self) -> ComplexSqrt: p = self.momentum @@ -180,7 +180,7 @@ class NegativeMomentum(sp.Expr): r"""Invert the spatial components of a `.FourMomentumSymbol`.""" momentum: sp.Basic - _latex_repr = R"-\left({momentum}\right)" + _latex_repr_ = R"-\left({momentum}\right)" def evaluate(self) -> sp.Expr: p = self.momentum @@ -265,7 +265,7 @@ class _BoostZMatrixImplementation(NumPyPrintable): gamma_beta: sp.Basic ones: _OnesArray zeros: _ZerosArray - _latex_repr = R"\boldsymbol{{B_z}}\left({beta}\right)" + _latex_repr_ = R"\boldsymbol{{B_z}}\left({beta}\right)" def _numpycode(self, printer: NumPyPrinter, *args) -> str: printer.module_imports[printer._module].add("array") @@ -285,7 +285,7 @@ class BoostMatrix(sp.Expr): r"""Compute a rank-3 Lorentz boost matrix from a `.FourMomentumSymbol`.""" momentum: sp.Basic - _latex_repr = R"\boldsymbol{{B}}\left({momentum}\right)" + _latex_repr_ = R"\boldsymbol{{B}}\left({momentum}\right)" def as_explicit(self) -> sp.MutableDenseMatrix: momentum = self.momentum @@ -353,7 +353,7 @@ class _BoostMatrixImplementation(NumPyPrintable): b22: sp.Basic b23: sp.Basic b33: sp.Basic - _latex_repr = R"\boldsymbol{{B}}\left({momentum}\right)" + _latex_repr_ = R"\boldsymbol{{B}}\left({momentum}\right)" def _numpycode(self, printer: NumPyPrinter, *args) -> str: _, b00, b01, b02, b03, b11, b12, b13, b22, b23, b33 = self.args @@ -409,7 +409,7 @@ class _RotationYMatrixImplementation(NumPyPrintable): sin_angle: sp.Basic ones: _OnesArray zeros: _ZerosArray - _latex_repr = R"\boldsymbol{{R_y}}\left({angle}\right)" + _latex_repr_ = R"\boldsymbol{{R_y}}\left({angle}\right)" def _numpycode(self, printer: NumPyPrinter, *args) -> str: printer.module_imports[printer._module].add("array") @@ -466,7 +466,7 @@ class _RotationZMatrixImplementation(NumPyPrintable): sin_angle: sp.Basic ones: _OnesArray zeros: _ZerosArray - _latex_repr = R"\boldsymbol{{R_z}}\left({angle}\right)" + _latex_repr_ = R"\boldsymbol{{R_z}}\left({angle}\right)" def _numpycode(self, printer: NumPyPrinter, *args) -> str: printer.module_imports[printer._module].add("array") diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index 708416287..3cf0ac2b7 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -344,10 +344,15 @@ def __call__(self, printer: LatexPrinter, *args) -> str: ... @dataclass_transform(field_specifiers=(argument, _create_field)) def _implement_latex_repr(cls: type[T]) -> type[T]: - _latex_repr_: LatexMethod | str | None = getattr(cls, "_latex_repr_", None) + repr_name = "_latex_repr_" + repr_mistyped = "_latex_repr" + if hasattr(cls, repr_mistyped): + msg = f"Class defines a {repr_mistyped} attribute, but it should be {repr_name}" + raise AttributeError(msg) + _latex_repr_: LatexMethod | str | None = getattr(cls, repr_name, None) if _latex_repr_ is None: msg = ( - "You need to define a _latex_repr_ str or method in order to decorate an" + f"You need to define a {repr_name} str or method in order to decorate an" " unevaluated expression with a printer method for LaTeX representation." ) raise NotImplementedError(msg) diff --git a/tests/kinematics/test_lorentz.py b/tests/kinematics/test_lorentz.py index 096ad2ae5..55048c120 100644 --- a/tests/kinematics/test_lorentz.py +++ b/tests/kinematics/test_lorentz.py @@ -193,6 +193,12 @@ def test_latex(self): class TestInvariantMass: + def test_latex(self): + p = FourMomentumSymbol("p1", shape=[]) + mass = InvariantMass(p) + latex = sp.latex(mass) + assert latex == "m_{p_{1}}" + @pytest.mark.parametrize( ("state_id", "expected_mass"), [