diff --git a/.cspell.json b/.cspell.json index 61abb8b32..8fbe8e02d 100644 --- a/.cspell.json +++ b/.cspell.json @@ -190,6 +190,8 @@ "kmatrix", "kutschke", "kwargs", + "lambdifygenerated", + "lambdifying", "linestyle", "linewidth", "linkcheck", diff --git a/docs/_extend_docstrings.py b/docs/_extend_docstrings.py index c42baa26b..f21ad0762 100644 --- a/docs/_extend_docstrings.py +++ b/docs/_extend_docstrings.py @@ -19,7 +19,9 @@ import sympy as sp from sympy.printing.numpy import NumPyPrinter -from ampform.kinematics import FourMomentumSymbol +from ampform.kinematics import FourMomentumSymbol, _ArraySize +from ampform.sympy import NumPyPrintable +from ampform.sympy._array_expressions import ArrayMultiplication logging.getLogger().setLevel(logging.ERROR) @@ -60,8 +62,8 @@ def extend_BlattWeisskopfSquared() -> None: def extend_BoostZMatrix() -> None: from ampform.kinematics import BoostZMatrix - beta = sp.Symbol("beta") - expr = BoostZMatrix(beta) + beta, n_events = sp.symbols("beta n") + expr = BoostZMatrix(beta, n_events) _append_to_docstring( BoostZMatrix, f"""\n @@ -81,7 +83,45 @@ def extend_BoostZMatrix() -> None: """, ) b = sp.Symbol("b") - _append_code_rendering(BoostZMatrix(b)) + _append_code_rendering( + BoostZMatrix(b).doit(), + use_cse=True, + docstring_class=BoostZMatrix, + ) + + from ampform.kinematics import RotationYMatrix, RotationZMatrix + + _append_to_docstring( + BoostZMatrix, + """ + Note that this code was generated with :func:`sympy.lambdify + ` with :code:`cse=True`. The repetition + of :func:`numpy.ones` is still bothersome, but these sub-nodes is also + extracted by :func:`sympy.cse ` if the + expression is nested further down in an :doc:`expression tree + `, for instance when boosting a + `.FourMomentumSymbol` :math:`p` in the :math:`z`-direction: + """, + ) + p, beta, phi, theta = sp.symbols("p beta phi theta") + expr = ArrayMultiplication( + BoostZMatrix(beta, n_events=_ArraySize(p)), + RotationYMatrix(theta, n_events=_ArraySize(p)), + RotationZMatrix(phi, n_events=_ArraySize(p)), + p, + ) + _append_to_docstring( + BoostZMatrix, + f"""\n + .. math:: {sp.latex(expr)} + :label: boost-in-z-direction + + which in :mod:`numpy` code becomes: + """, + ) + _append_code_rendering( + expr.doit(), use_cse=True, docstring_class=BoostZMatrix + ) def extend_BreakupMomentumSquared() -> None: @@ -235,8 +275,8 @@ def extend_Phi() -> None: def extend_RotationYMatrix() -> None: from ampform.kinematics import RotationYMatrix - angle = sp.Symbol("alpha") - expr = RotationYMatrix(angle) + angle, n_events = sp.symbols("alpha n") + expr = RotationYMatrix(angle, n_events) _append_to_docstring( RotationYMatrix, f"""\n @@ -254,8 +294,8 @@ def extend_RotationYMatrix() -> None: def extend_RotationZMatrix() -> None: from ampform.kinematics import RotationZMatrix - angle = sp.Symbol("alpha") - expr = RotationZMatrix(angle) + angle, n_events = sp.symbols("alpha n") + expr = RotationZMatrix(angle, n_events) _append_to_docstring( RotationZMatrix, f"""\n @@ -276,7 +316,17 @@ def extend_RotationZMatrix() -> None: """, ) a = sp.Symbol("a") - _append_code_rendering(RotationZMatrix(a)) + _append_code_rendering( + RotationZMatrix(a).doit(), + use_cse=True, + docstring_class=RotationZMatrix, + ) + _append_to_docstring( + RotationZMatrix, + """ + See also the note that comes with Equation :eq:`boost-in-z-direction`. + """, + ) def extend_Theta() -> None: @@ -423,19 +473,42 @@ def extend_relativistic_breit_wigner_with_ff() -> None: ) -def _append_code_rendering(expr: sp.Expr) -> None: +def _append_code_rendering( + expr: NumPyPrintable, + use_cse: bool = False, + docstring_class: Optional[type] = None, +) -> None: printer = NumPyPrinter() - numpy_code = expr._numpycode(printer) + if use_cse: + args = sorted(expr.free_symbols, key=str) + func = sp.lambdify(args, expr, cse=True, printer=printer) + numpy_code = inspect.getsource(func) + else: + numpy_code = expr._numpycode(printer) import_statements = __print_imports(printer) - _append_to_docstring( - type(expr), - f"""\n - .. code:: - + if docstring_class is None: + docstring_class = type(expr) + numpy_code = textwrap.dedent(numpy_code) + numpy_code = textwrap.indent(numpy_code, prefix=8 * " ").strip() + options = "" + if ( + max(__get_text_width(import_statements), __get_text_width(numpy_code)) + > 90 + ): + options += ":class: full-width\n" + appended_text = f"""\n + .. code-block:: python + {options} {import_statements} {numpy_code} - """, - ) + """ + _append_to_docstring(docstring_class, appended_text) + + +def __get_text_width(text: str) -> int: + lines = text.split("\n") + widths = map(len, lines) + return max(widths) def _append_latex_doit_definition( diff --git a/docs/conf.py b/docs/conf.py index 0b8456595..ee3461a2a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -170,6 +170,9 @@ def fetch_logo(url: str, output_path: str) -> None: autodoc_typehints_format = "short" codeautolink_concat_default = True codeautolink_global_preface = """ +import numpy +import numpy as np +import sympy as sp from IPython.display import display """ AUTODOC_INSERT_SIGNATURE_LINEBREAKS = False diff --git a/src/ampform/kinematics.py b/src/ampform/kinematics.py index 2e15c6f8e..cae056663 100644 --- a/src/ampform/kinematics.py +++ b/src/ampform/kinematics.py @@ -4,7 +4,17 @@ import itertools import sys -from typing import TYPE_CHECKING, Any, Dict, List, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) import attr import sympy as sp @@ -334,19 +344,25 @@ def _latex(self, printer: LatexPrinter, *args: Any) -> str: return Rf"\theta\left({momentum}\right)" -class BoostZMatrix(NumPyPrintable): - """Represents a Lorentz boost matrix in the :math:`z`-direction.""" +@implement_doit_method +class BoostZMatrix(UnevaluatedExpression): + r"""Represents a Lorentz boost matrix in the :math:`z`-direction. - def __new__(cls, beta: sp.Expr, **kwargs: Any) -> "BoostZMatrix": - return create_expression(cls, beta, **kwargs) + Args: + beta: Velocity in the :math:`z`-direction, :math:`\beta=p_z/E`. + n_events: Number of events :math:`n` for this matrix array of shape + :math:`n\times4\times4`. Defaults to the `len` of :code:`beta`. + """ - @property - def beta(self) -> sp.Expr: - r"""Velocity in the :math:`z`-direction, :math:`\beta=p_z/E`.""" - return self.args[0] + def __new__( + cls, beta: sp.Expr, n_events: Optional[sp.Symbol] = None, **kwargs: Any + ) -> "BoostZMatrix": + if n_events is None: + n_events = _ArraySize(beta) + return create_expression(cls, beta, n_events, **kwargs) def as_explicit(self) -> sp.Expr: - beta = self.beta + beta = self.args[0] gamma = 1 / sp.sqrt(1 - beta**2) return sp.Matrix( [ @@ -357,42 +373,72 @@ def as_explicit(self) -> sp.Expr: ] ) + def evaluate(self) -> "_BoostZMatrixImplementation": + beta = self.args[0] + gamma = 1 / sp.sqrt(1 - beta**2) + n_events = self.args[1] + return _BoostZMatrixImplementation( + beta=beta, + gamma=gamma, + gamma_beta=gamma * beta, + ones=_OnesArray(n_events), + zeros=_ZerosArray(n_events), + ) + def _latex(self, printer: LatexPrinter, *args: Any) -> str: - beta = printer._print(self.beta) + return printer._print(self.evaluate(), *args) + + +class _BoostZMatrixImplementation(NumPyPrintable): + def __new__( # pylint: disable=too-many-arguments + cls, + beta: sp.Expr, + gamma: sp.Expr, + gamma_beta: sp.Expr, + ones: "_OnesArray", + zeros: "_ZerosArray", + **hints: Any, + ) -> "_BoostZMatrixImplementation": + return create_expression( + cls, beta, gamma, gamma_beta, ones, zeros, **hints + ) + + def _latex(self, printer: LatexPrinter, *args: Any) -> str: + beta = printer._print(self.args[0]) return Rf"\boldsymbol{{B_z}}\left({beta}\right)" def _numpycode(self, printer: NumPyPrinter, *args: Any) -> str: - printer.module_imports[printer._module].update( - {"array", "ones", "zeros", "sqrt"} - ) - beta = printer._print(self.beta) - gamma = f"1 / sqrt(1 - ({beta}) ** 2)" - n_events = f"len({beta})" - zeros = f"zeros({n_events})" - ones = f"ones({n_events})" + printer.module_imports[printer._module].add("array") + _, gamma, gamma_beta, ones, zeros = map(printer._print, self.args) return f"""array( [ - [{gamma}, {zeros}, {zeros}, -{gamma} * {beta}], + [{gamma}, {zeros}, {zeros}, -{gamma_beta}], [{zeros}, {ones}, {zeros}, {zeros}], [{zeros}, {zeros}, {ones}, {zeros}], - [-{gamma} * {beta}, {zeros}, {zeros}, {gamma}], + [-{gamma_beta}, {zeros}, {zeros}, {gamma}], ] ).transpose((2, 0, 1))""" -class RotationYMatrix(NumPyPrintable): - """Rotation matrix around the :math:`y`-axis for a `FourMomentumSymbol`.""" +@implement_doit_method +class RotationYMatrix(UnevaluatedExpression): + r"""Rotation matrix around the :math:`y`-axis for a `FourMomentumSymbol`. - def __new__(cls, angle: sp.Expr, **hints: Any) -> "RotationYMatrix": - return create_expression(cls, angle, **hints) + Args: + angle: Angle with which to rotate, see e.g. `Phi` and `Theta`. + n_events: Number of events :math:`n` for this matrix array of shape + :math:`n\times4\times4`. Defaults to the `len` of :code:`angle`. + """ - @property - def angle(self) -> sp.Expr: - """Angle with which to rotate, see e.g. `Phi` and `Theta`.""" - return self.args[0] + def __new__( + cls, angle: sp.Expr, n_events: Optional[sp.Symbol] = None, **hints: Any + ) -> "RotationYMatrix": + if n_events is None: + n_events = _ArraySize(angle) + return create_expression(cls, angle, n_events, **hints) def as_explicit(self) -> sp.Expr: - angle = self.angle + angle = self.args[0] return sp.Matrix( [ [1, 0, 0, 0], @@ -402,39 +448,69 @@ def as_explicit(self) -> sp.Expr: ] ) + def evaluate(self) -> "_RotationYMatrixImplementation": + angle = self.args[0] + n_events = self.args[1] + return _RotationYMatrixImplementation( + angle=angle, + cos_angle=sp.cos(angle), + sin_angle=sp.sin(angle), + ones=_OnesArray(n_events), + zeros=_ZerosArray(n_events), + ) + + def _latex(self, printer: LatexPrinter, *args: Any) -> str: + return printer._print(self.evaluate(), *args) + + +class _RotationYMatrixImplementation(NumPyPrintable): + def __new__( # pylint: disable=too-many-arguments + cls, + angle: sp.Expr, + cos_angle: sp.Expr, + sin_angle: sp.Expr, + ones: "_OnesArray", + zeros: "_ZerosArray", + **hints: Any, + ) -> "_RotationYMatrixImplementation": + return create_expression( + cls, angle, cos_angle, sin_angle, ones, zeros, **hints + ) + def _latex(self, printer: LatexPrinter, *args: Any) -> str: angle, *_ = self.args angle = printer._print(angle) return Rf"\boldsymbol{{R_y}}\left({angle}\right)" def _numpycode(self, printer: NumPyPrinter, *args: Any) -> str: - printer.module_imports[printer._module].update( - {"array", "cos", "ones", "zeros", "sin"} - ) - angle = printer._print(self.angle) - n_events = f"len({angle})" - zeros = f"zeros({n_events})" - ones = f"ones({n_events})" + printer.module_imports[printer._module].add("array") + _, cos_angle, sin_angle, ones, zeros = map(printer._print, self.args) return f"""array( [ [{ones}, {zeros}, {zeros}, {zeros}], - [{zeros}, cos({angle}), {zeros}, sin({angle})], + [{zeros}, {cos_angle}, {zeros}, {sin_angle}], [{zeros}, {zeros}, {ones}, {zeros}], - [{zeros}, -sin({angle}), {zeros}, cos({angle})], + [{zeros}, -{sin_angle}, {zeros}, {cos_angle}], ] ).transpose((2, 0, 1))""" -class RotationZMatrix(NumPyPrintable): - """Rotation matrix around the :math:`z`-axis for a `FourMomentumSymbol`.""" +@implement_doit_method +class RotationZMatrix(UnevaluatedExpression): + r"""Rotation matrix around the :math:`z`-axis for a `FourMomentumSymbol`. - def __new__(cls, angle: sp.Expr, **hints: Any) -> "RotationZMatrix": - return create_expression(cls, angle, **hints) + Args: + angle: Angle with which to rotate, see e.g. `Phi` and `Theta`. + n_events: Number of events :math:`n` for this matrix array of shape + :math:`n\times4\times4`. Defaults to the `len` of :code:`angle`. + """ - @property - def angle(self) -> sp.Expr: - """Angle with which to rotate, see e.g. `Phi` and `Theta`.""" - return self.args[0] + def __new__( + cls, angle: sp.Expr, n_events: Optional[sp.Symbol] = None, **hints: Any + ) -> "RotationZMatrix": + if n_events is None: + n_events = _ArraySize(angle) + return create_expression(cls, angle, n_events, **hints) def as_explicit(self) -> sp.Expr: angle = self.args[0] @@ -447,29 +523,86 @@ def as_explicit(self) -> sp.Expr: ] ) + def evaluate(self) -> "_RotationZMatrixImplementation": + angle = self.args[0] + n_events = self.args[1] + return _RotationZMatrixImplementation( + angle=angle, + cos_angle=sp.cos(angle), + sin_angle=sp.sin(angle), + ones=_OnesArray(n_events), + zeros=_ZerosArray(n_events), + ) + + def _latex(self, printer: LatexPrinter, *args: Any) -> str: + return printer._print(self.evaluate(), *args) + + +class _RotationZMatrixImplementation(NumPyPrintable): + def __new__( # pylint: disable=too-many-arguments + cls, + angle: sp.Expr, + cos_angle: sp.Expr, + sin_angle: sp.Expr, + ones: "_OnesArray", + zeros: "_ZerosArray", + **hints: Any, + ) -> "_RotationZMatrixImplementation": + return create_expression( + cls, angle, cos_angle, sin_angle, ones, zeros, **hints + ) + def _latex(self, printer: LatexPrinter, *args: Any) -> str: angle, *_ = self.args angle = printer._print(angle) return Rf"\boldsymbol{{R_z}}\left({angle}\right)" def _numpycode(self, printer: NumPyPrinter, *args: Any) -> str: - printer.module_imports[printer._module].update( - {"array", "cos", "ones", "zeros", "sin"} - ) - angle = printer._print(self.angle) - n_events = f"len({angle})" - zeros = f"zeros({n_events})" - ones = f"ones({n_events})" + printer.module_imports[printer._module].add("array") + _, cos_angle, sin_angle, ones, zeros = map(printer._print, self.args) return f"""array( [ [{ones}, {zeros}, {zeros}, {zeros}], - [{zeros}, cos({angle}), -sin({angle}), {zeros}], - [{zeros}, sin({angle}), cos({angle}), {zeros}], + [{zeros}, {cos_angle}, -{sin_angle}, {zeros}], + [{zeros}, {sin_angle}, {cos_angle}, {zeros}], [{zeros}, {zeros}, {zeros}, {ones}], ] ).transpose((2, 0, 1))""" +class _OnesArray(NumPyPrintable): + def __new__( + cls, shape: Union[int, Sequence[int]], **kwargs: Any + ) -> "_OnesArray": + return create_expression(cls, shape, **kwargs) + + def _numpycode(self, printer: NumPyPrinter, *args: Any) -> str: + printer.module_imports[printer._module].add("ones") + shape = printer._print(self.args[0]) + return f"ones({shape})" + + +class _ZerosArray(NumPyPrintable): + def __new__( + cls, shape: Union[int, Sequence[int]], **kwargs: Any + ) -> "_ZerosArray": + return create_expression(cls, shape, **kwargs) + + def _numpycode(self, printer: NumPyPrinter, *args: Any) -> str: + printer.module_imports[printer._module].add("zeros") + shape = printer._print(self.args[0]) + return f"zeros({shape})" + + +class _ArraySize(NumPyPrintable): + def __new__(cls, array: sp.Basic, **kwargs: Any) -> "_ArraySize": + return create_expression(cls, array, **kwargs) + + def _numpycode(self, printer: NumPyPrinter, *args: Any) -> str: + shape = printer._print(self.args[0]) + return f"len({shape})" + + def compute_helicity_angles( four_momenta: "FourMomenta", topology: Topology ) -> Dict[str, sp.Expr]: @@ -496,6 +629,8 @@ def compute_helicity_angles( f"final state edge IDs {set(topology.outgoing_edge_ids)}" ) + n_events = _get_number_of_events(four_momenta) + def __recursive_helicity_angles( # pylint: disable=too-many-locals four_momenta: FourMomenta, node_id: int ) -> Dict[str, sp.Expr]: @@ -533,9 +668,9 @@ def __recursive_helicity_angles( # pylint: disable=too-many-locals beta = p3_norm / Energy(four_momentum) new_momentum_pool = { k: ArrayMultiplication( - BoostZMatrix(beta), - RotationYMatrix(-theta), - RotationZMatrix(-phi), + BoostZMatrix(beta, n_events), + RotationYMatrix(-theta, n_events), + RotationZMatrix(-phi, n_events), p, ) for k, p in four_momenta.items() @@ -566,6 +701,13 @@ def __recursive_helicity_angles( # pylint: disable=too-many-locals ) +def _get_number_of_events( + four_momenta: "FourMomenta", +) -> "_ArraySize": + sorted_momentum_symbols = sorted(four_momenta.values(), key=str) + return _ArraySize(sorted_momentum_symbols[0]) + + def compute_invariant_masses( four_momenta: "FourMomenta", topology: Topology ) -> Dict[str, sp.Expr]: diff --git a/src/ampform/sympy/_array_expressions.py b/src/ampform/sympy/_array_expressions.py index b880b011c..dcee5f282 100644 --- a/src/ampform/sympy/_array_expressions.py +++ b/src/ampform/sympy/_array_expressions.py @@ -377,7 +377,7 @@ def _latex(self, printer: LatexPrinter, *args: Any) -> str: return " ".join(tensors) def _numpycode(self, printer: NumPyPrinter, *args: Any) -> str: - printer.module_imports[printer._module].update({"einsum", "transpose"}) + printer.module_imports[printer._module].add("einsum") tensors = list(map(printer._print, self.args)) if len(tensors) == 0: return "" diff --git a/tests/conftest.py b/tests/conftest.py index 63cb71afd..2f3c3dbec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ from _pytest.config import Config from _pytest.fixtures import SubRequest from qrules import ParticleCollection, ReactionInfo, load_default_particles +from qrules.settings import NumberOfThreads from ampform import get_builder from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff @@ -15,6 +16,10 @@ logging.getLogger().setLevel(level=logging.ERROR) +# Ensure consistent test coverage when running pytest multithreaded +# https://github.com/ComPWA/qrules/issues/11 +NumberOfThreads.set(1) + @pytest.fixture(scope="session") def particle_database() -> ParticleCollection: diff --git a/tests/test_kinematics.py b/tests/test_kinematics.py index 346328b34..1ec6ccc1d 100644 --- a/tests/test_kinematics.py +++ b/tests/test_kinematics.py @@ -1,5 +1,7 @@ # pylint: disable=no-member, no-self-use, redefined-outer-name # cspell:ignore atol doprint +import inspect +import textwrap from typing import Dict, Tuple import numpy as np @@ -19,14 +21,23 @@ FourMomentumZ, InvariantMass, Phi, + RotationYMatrix, + RotationZMatrix, Theta, ThreeMomentumNorm, + _ArraySize, + _OnesArray, + _ZerosArray, compute_helicity_angles, compute_invariant_masses, create_four_momentum_symbols, determine_attached_final_state, ) -from ampform.sympy._array_expressions import ArraySlice, ArraySymbol +from ampform.sympy._array_expressions import ( + ArrayMultiplication, + ArraySlice, + ArraySymbol, +) @pytest.fixture(scope="session") @@ -52,7 +63,9 @@ def helicity_angles( class TestBoostZMatrix: def test_boost_into_own_rest_frame_gives_mass(self): p = FourMomentumSymbol("p") - expr = BoostZMatrix(ThreeMomentumNorm(p) / Energy(p)) + n_events = _ArraySize(p) + beta = ThreeMomentumNorm(p) / Energy(p) + expr = BoostZMatrix(beta, n_events) func = sp.lambdify(p, expr.doit()) p_array = np.array([[5, 0, 0, 1]]) boost_z = func(p_array)[0] @@ -65,6 +78,48 @@ def test_boost_into_own_rest_frame_gives_mass(self): mass_array = func(p_array) assert pytest.approx(mass_array[0]) == mass + def test_numpycode_cse_in_expression_tree(self): + p, beta, phi, theta = sp.symbols("p beta phi theta") + expr = ArrayMultiplication( + BoostZMatrix(beta, n_events=_ArraySize(p)), + RotationYMatrix(theta, n_events=_ArraySize(p)), + RotationZMatrix(phi, n_events=_ArraySize(p)), + p, + ) + func = sp.lambdify([], expr.doit(), cse=True) + src = inspect.getsource(func) + expected_src = """ + def _lambdifygenerated(): + x0 = 1/sqrt(1 - beta**2) + x1 = len(p) + x2 = ones(x1) + x3 = zeros(x1) + return (einsum("...ij,...jk,...kl,...l->...i", array( + [ + [x0, x3, x3, -beta*x0], + [x3, x2, x3, x3], + [x3, x3, x2, x3], + [-beta*x0, x3, x3, x0], + ] + ).transpose((2, 0, 1)), array( + [ + [x2, x3, x3, x3], + [x3, cos(theta), x3, sin(theta)], + [x3, x3, x2, x3], + [x3, -sin(theta), x3, cos(theta)], + ] + ).transpose((2, 0, 1)), array( + [ + [x2, x3, x3, x3], + [x3, cos(phi), -sin(phi), x3], + [x3, sin(phi), cos(phi), x3], + [x3, x3, x3, x2], + ] + ).transpose((2, 0, 1)), p)) + """ + expected_src = textwrap.dedent(expected_src) + assert src.strip() == expected_src.strip() + class TestFourMomentumXYZ: def symbols( @@ -172,6 +227,129 @@ def test_numpy(self): ) +class TestRotationYMatrix: + @pytest.fixture(scope="session") + def rotation_expr(self): + angle, n_events = sp.symbols("a n") + return RotationYMatrix(angle, n_events) + + @pytest.fixture(scope="session") + def rotation_func(self, rotation_expr): + angle = sp.Symbol("a") + rotation_expr = rotation_expr.doit() + rotation_expr = rotation_expr.subs(sp.Symbol("n"), _ArraySize(angle)) + return sp.lambdify(angle, rotation_expr, cse=True) + + def test_numpycode_cse(self, rotation_expr: RotationYMatrix): + func = sp.lambdify([], rotation_expr.doit(), cse=True) + src = inspect.getsource(func) + expected_src = """ + def _lambdifygenerated(): + return (array( + [ + [ones(n), zeros(n), zeros(n), zeros(n)], + [zeros(n), cos(a), zeros(n), sin(a)], + [zeros(n), zeros(n), ones(n), zeros(n)], + [zeros(n), -sin(a), zeros(n), cos(a)], + ] + ).transpose((2, 0, 1))) + """ + expected_src = textwrap.dedent(expected_src) + assert src.strip() == expected_src.strip() + + def test_rotation_over_pi_flips_xz(self, rotation_func): + vectors = np.array([[1, 1, 1, 1]]) + angle_array = np.array([np.pi]) + rotated_vectors = np.einsum( + "...ij,...j->...j", rotation_func(angle_array), vectors + ) + assert pytest.approx(rotated_vectors) == np.array([[1, -1, 1, -1]]) + + +class TestRotationZMatrix: + @pytest.fixture(scope="session") + def rotation_expr(self): + angle, n_events = sp.symbols("a n") + return RotationZMatrix(angle, n_events) + + @pytest.fixture(scope="session") + def rotation_func(self, rotation_expr): + angle = sp.Symbol("a") + rotation_expr = rotation_expr.doit() + rotation_expr = rotation_expr.subs(sp.Symbol("n"), _ArraySize(angle)) + return sp.lambdify(angle, rotation_expr, cse=True) + + def test_numpycode_cse(self, rotation_expr: RotationZMatrix): + func = sp.lambdify([], rotation_expr.doit(), cse=True) + src = inspect.getsource(func) + expected_src = """ + def _lambdifygenerated(): + return (array( + [ + [ones(n), zeros(n), zeros(n), zeros(n)], + [zeros(n), cos(a), -sin(a), zeros(n)], + [zeros(n), sin(a), cos(a), zeros(n)], + [zeros(n), zeros(n), zeros(n), ones(n)], + ] + ).transpose((2, 0, 1))) + """ + expected_src = textwrap.dedent(expected_src) + assert src.strip() == expected_src.strip() + + def test_rotation_over_pi_flips_xy(self, rotation_func): + vectors = np.array([[1, 1, 1, 1]]) + angle_array = np.array([np.pi]) + rotated_vectors = np.einsum( + "...ij,...j->...j", rotation_func(angle_array), vectors + ) + assert pytest.approx(rotated_vectors) == np.array([[1, -1, -1, 1]]) + + +@pytest.mark.parametrize("rotation", [RotationYMatrix, RotationZMatrix]) +def test_rotation_latex_repr_is_identical_with_doit(rotation): + angle, n_events = sp.symbols("a n") + expr = rotation(angle, n_events) + assert sp.latex(expr) == sp.latex(expr.doit()) + + +@pytest.mark.parametrize("rotation", [RotationYMatrix, RotationZMatrix]) +def test_rotation_over_multiple_two_pi_is_identity(rotation): + angle = sp.Symbol("a") + expr = rotation(angle) + func = sp.lambdify(angle, expr.doit(), cse=True) + angle_array = np.arange(-2, 4, 1) * 2 * np.pi + rotation_matrices = func(angle_array) + identity = np.array( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ] + ) + identity = np.tile(identity, reps=(len(angle_array), 1, 1)) + assert pytest.approx(rotation_matrices) == identity + + +class TestOnesZerosArray: + @pytest.mark.parametrize("array_type", ["ones", "zeros"]) + @pytest.mark.parametrize("shape", [10, (4, 2), [3, 5, 7]]) + def test_numpycode(self, array_type, shape): + if array_type == "ones": + expr_class = _OnesArray + array_func = np.ones + elif array_type == "zeros": + expr_class = _ZerosArray + array_func = np.zeros + else: + raise NotImplementedError + array_expr = expr_class(shape) + create_array = sp.lambdify([], array_expr) + array = create_array() + np.testing.assert_array_equal(array, array_func(shape)) + + +@pytest.mark.parametrize("use_cse", [False, True]) @pytest.mark.parametrize( ("angle_name", "expected_values"), [ @@ -279,7 +457,8 @@ def test_numpy(self): ), ], ) -def test_compute_helicity_angles( +def test_compute_helicity_angles( # pylint: disable=too-many-arguments + use_cse: bool, data_sample: Dict[int, np.ndarray], topology_and_momentum_symbols: Tuple[Topology, FourMomenta], angle_name: str, @@ -289,7 +468,7 @@ def test_compute_helicity_angles( _, momentum_symbols = topology_and_momentum_symbols four_momenta = data_sample.values() expr = helicity_angles[angle_name] - np_angle = sp.lambdify(momentum_symbols.values(), expr.doit()) + np_angle = sp.lambdify(momentum_symbols.values(), expr.doit(), cse=use_cse) computed = np_angle(*four_momenta) np.testing.assert_allclose(computed, expected_values, atol=1e-5)