Skip to content

Commit

Permalink
FIX: implement correct latex repr method name (#393)
Browse files Browse the repository at this point in the history
* ENH: check for `_latex_repr_` typo
* MAINT: write latex test for `InvariantMass`
  • Loading branch information
redeboer committed Feb 5, 2024
1 parent 8822954 commit 84aa469
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
16 changes: 8 additions & 8 deletions src/ampform/kinematics/lorentz.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class EuclideanNorm(NumPyPrintable):
"""Take the euclidean norm of an array over axis 1."""

vector: sp.Basic
_latex_repr = R"\left|{vector}\right|"
_latex_repr_ = R"\left|{vector}\right|"

def evaluate(self) -> ArraySlice:
return sp.sqrt(EuclideanNormSquared(self.vector))
Expand Down Expand Up @@ -163,7 +163,7 @@ class InvariantMass(sp.Expr):
"""Invariant mass of a `.FourMomentumSymbol`."""

momentum: sp.Basic
_latex_repr = "m_{{{momentum}}}"
_latex_repr_ = "m_{{{momentum}}}"

def evaluate(self) -> ComplexSqrt:
p = self.momentum
Expand All @@ -176,7 +176,7 @@ class NegativeMomentum(sp.Expr):
r"""Invert the spatial components of a `.FourMomentumSymbol`."""

momentum: sp.Basic
_latex_repr = R"-\left({momentum}\right)"
_latex_repr_ = R"-\left({momentum}\right)"

def evaluate(self) -> sp.Expr:
p = self.momentum
Expand Down Expand Up @@ -261,7 +261,7 @@ class _BoostZMatrixImplementation(NumPyPrintable):
gamma_beta: sp.Basic
ones: _OnesArray
zeros: _ZerosArray
_latex_repr = R"\boldsymbol{{B_z}}\left({beta}\right)"
_latex_repr_ = R"\boldsymbol{{B_z}}\left({beta}\right)"

def _numpycode(self, printer: NumPyPrinter, *args) -> str:
printer.module_imports[printer._module].add("array")
Expand All @@ -281,7 +281,7 @@ class BoostMatrix(sp.Expr):
r"""Compute a rank-3 Lorentz boost matrix from a `.FourMomentumSymbol`."""

momentum: sp.Basic
_latex_repr = R"\boldsymbol{{B}}\left({momentum}\right)"
_latex_repr_ = R"\boldsymbol{{B}}\left({momentum}\right)"

def as_explicit(self) -> sp.MutableDenseMatrix:
momentum = self.momentum
Expand Down Expand Up @@ -349,7 +349,7 @@ class _BoostMatrixImplementation(NumPyPrintable):
b22: sp.Basic
b23: sp.Basic
b33: sp.Basic
_latex_repr = R"\boldsymbol{{B}}\left({momentum}\right)"
_latex_repr_ = R"\boldsymbol{{B}}\left({momentum}\right)"

def _numpycode(self, printer: NumPyPrinter, *args) -> str:
_, b00, b01, b02, b03, b11, b12, b13, b22, b23, b33 = self.args
Expand Down Expand Up @@ -405,7 +405,7 @@ class _RotationYMatrixImplementation(NumPyPrintable):
sin_angle: sp.Basic
ones: _OnesArray
zeros: _ZerosArray
_latex_repr = R"\boldsymbol{{R_y}}\left({angle}\right)"
_latex_repr_ = R"\boldsymbol{{R_y}}\left({angle}\right)"

def _numpycode(self, printer: NumPyPrinter, *args) -> str:
printer.module_imports[printer._module].add("array")
Expand Down Expand Up @@ -462,7 +462,7 @@ class _RotationZMatrixImplementation(NumPyPrintable):
sin_angle: sp.Basic
ones: _OnesArray
zeros: _ZerosArray
_latex_repr = R"\boldsymbol{{R_z}}\left({angle}\right)"
_latex_repr_ = R"\boldsymbol{{R_z}}\left({angle}\right)"

def _numpycode(self, printer: NumPyPrinter, *args) -> str:
printer.module_imports[printer._module].add("array")
Expand Down
9 changes: 7 additions & 2 deletions src/ampform/sympy/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,15 @@ def __call__(self, printer: LatexPrinter, *args) -> str: ...

@dataclass_transform(field_specifiers=(argument, _create_field))
def _implement_latex_repr(cls: type[T]) -> type[T]:
_latex_repr_: LatexMethod | str | None = getattr(cls, "_latex_repr_", None)
repr_name = "_latex_repr_"
repr_mistyped = "_latex_repr"
if hasattr(cls, repr_mistyped):
msg = f"Class defines a {repr_mistyped} attribute, but it should be {repr_name}"
raise AttributeError(msg)
_latex_repr_: LatexMethod | str | None = getattr(cls, repr_name, None)
if _latex_repr_ is None:
msg = (
"You need to define a _latex_repr_ str or method in order to decorate an"
f"You need to define a {repr_name} str or method in order to decorate an"
" unevaluated expression with a printer method for LaTeX representation."
)
raise NotImplementedError(msg)
Expand Down
6 changes: 6 additions & 0 deletions tests/kinematics/test_lorentz.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ def test_latex(self):


class TestInvariantMass:
def test_latex(self):
p = FourMomentumSymbol("p1", shape=[])
mass = InvariantMass(p)
latex = sp.latex(mass)
assert latex == "m_{p_{1}}"

@pytest.mark.parametrize(
("state_id", "expected_mass"),
[
Expand Down

0 comments on commit 84aa469

Please sign in to comment.