From d28271f26db735b1a7bcac3768e31e5f88654538 Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Mon, 21 Jun 2021 18:44:58 +0200 Subject: [PATCH] test: parametrize fixtures (#86) * test: add type hint pytest Config * test: improve use of Input class * test: improve qrules imports * test: parametrize fixture for HelicityModel * test: parametrize fixture for ReactionInfo * test: rename count_parameters to n_parameters --- tests/conftest.py | 44 ++++++++--------------------- tests/test_angular_distributions.py | 7 +++-- tests/test_dynamics.py | 36 +++++++++-------------- tests/test_helicity.py | 34 +++++++--------------- tests/test_parity_prefactor.py | 41 +++++++++++++-------------- 5 files changed, 58 insertions(+), 104 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index a9fd01daf..915c6479f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,12 @@ # pylint: disable=redefined-outer-name import logging -from typing import Dict +from typing import Dict, Tuple import numpy as np import pytest import qrules +from _pytest.config import Config +from _pytest.fixtures import SubRequest from qrules import ParticleCollection, ReactionInfo, load_default_particles from ampform import get_builder @@ -21,53 +23,31 @@ def particle_database() -> ParticleCollection: @pytest.fixture(scope="session") -def output_dir(pytestconfig) -> str: +def output_dir(pytestconfig: Config) -> str: return f"{pytestconfig.rootpath}/tests/output/" -@pytest.fixture(scope="session") -def jpsi_to_gamma_pi_pi_canonical_solutions() -> ReactionInfo: - return qrules.generate_transitions( - initial_state=[("J/psi(1S)", [-1, 1])], - final_state=["gamma", "pi0", "pi0"], - allowed_intermediate_particles=["f(0)(980)", "f(0)(1500)"], - allowed_interaction_types="strong only", - formalism="canonical-helicity", - ) - - -@pytest.fixture(scope="session") -def jpsi_to_gamma_pi_pi_helicity_solutions() -> ReactionInfo: +@pytest.fixture(scope="session", params=["canonical-helicity", "helicity"]) +def reaction(request: SubRequest) -> ReactionInfo: + formalism: str = request.param return qrules.generate_transitions( initial_state=[("J/psi(1S)", [-1, 1])], final_state=["gamma", "pi0", "pi0"], allowed_intermediate_particles=["f(0)(980)", "f(0)(1500)"], - allowed_interaction_types="strong only", - formalism="helicity", + allowed_interaction_types="strong", + formalism=formalism, ) @pytest.fixture(scope="session") -def jpsi_to_gamma_pi_pi_canonical_amplitude_model( - jpsi_to_gamma_pi_pi_canonical_solutions: ReactionInfo, -) -> HelicityModel: - return __create_model(jpsi_to_gamma_pi_pi_canonical_solutions) - - -@pytest.fixture(scope="session") -def jpsi_to_gamma_pi_pi_helicity_amplitude_model( - jpsi_to_gamma_pi_pi_helicity_solutions: ReactionInfo, -) -> HelicityModel: - return __create_model(jpsi_to_gamma_pi_pi_helicity_solutions) - - -def __create_model(reaction: ReactionInfo) -> HelicityModel: +def amplitude_model(reaction: ReactionInfo) -> Tuple[str, HelicityModel]: model_builder = get_builder(reaction) for name in reaction.get_intermediate_particles().names: model_builder.set_dynamics( name, create_relativistic_breit_wigner_with_ff ) - return model_builder.generate() + model = model_builder.generate() + return reaction.formalism, model # https://github.com/ComPWA/tensorwaves/blob/3d0ec44/tests/physics/helicity_formalism/test_helicity_angles.py#L61-L98 diff --git a/tests/test_angular_distributions.py b/tests/test_angular_distributions.py index d62ad042b..701fcc79a 100644 --- a/tests/test_angular_distributions.py +++ b/tests/test_angular_distributions.py @@ -7,6 +7,7 @@ import qrules import sympy as sp from qrules import ParticleCollection +from qrules.particle import Particle from ampform import get_builder @@ -57,13 +58,13 @@ def normalize( class TestEpemToDmD0Pip: @pytest.fixture(scope="class") def sympy_model(self, particle_database: ParticleCollection) -> sp.Expr: - epem = qrules.particle.Particle( + epem = Particle( name="EpEm", pid=12345678, mass=4.36, spin=1.0, - parity=qrules.particle.Parity(-1), - c_parity=qrules.particle.Parity(-1), + parity=-1, + c_parity=-1, ) particles = ParticleCollection(particle_database) particles.add(epem) diff --git a/tests/test_dynamics.py b/tests/test_dynamics.py index b681977af..0b587c8ba 100644 --- a/tests/test_dynamics.py +++ b/tests/test_dynamics.py @@ -1,35 +1,27 @@ # pylint: disable=no-self-use, too-many-arguments +from typing import Tuple + import numpy as np import pytest -import qrules import sympy as sp +from qrules import ParticleCollection from sympy import preorder_traversal from ampform.dynamics import ComplexSqrt from ampform.helicity import HelicityModel -@pytest.mark.parametrize( - ("formalism", "n_amplitudes", "n_parameters"), - [ - ("canonical", 16, 10), - ("helicity", 8, 8), - ], -) def test_generate( - formalism: str, - n_amplitudes: int, - n_parameters: int, - jpsi_to_gamma_pi_pi_canonical_amplitude_model: HelicityModel, - jpsi_to_gamma_pi_pi_helicity_amplitude_model: HelicityModel, - particle_database: qrules.ParticleCollection, + amplitude_model: Tuple[str, HelicityModel], + particle_database: ParticleCollection, ): - if formalism == "canonical": - model = jpsi_to_gamma_pi_pi_canonical_amplitude_model - elif formalism == "helicity": - model = jpsi_to_gamma_pi_pi_helicity_amplitude_model + formalism, model = amplitude_model + if formalism == "canonical-helicity": + n_amplitudes = 16 + n_parameters = 10 else: - raise NotImplementedError + n_amplitudes = 8 + n_parameters = 8 assert len(model.parameter_defaults) == n_parameters assert len(model.components) == 4 + n_amplitudes assert len(model.expression.free_symbols) == 7 + n_parameters @@ -81,14 +73,12 @@ def test_generate( expression = sp.piecewise_fold(expression) assert isinstance(expression, sp.Add) a1, a2 = tuple(map(str, expression.args)) - if formalism == "canonical": + if formalism == "canonical-helicity": assert a1 == "0.08/(-m**2 - 0.06*I*sqrt(m**2 - 0.07)/Abs(m) + 0.98)" assert a2 == "0.23/(-m**2 - 0.17*I*sqrt(m**2 - 0.07)/Abs(m) + 2.27)" - elif formalism == "helicity": + else: assert a1 == "0.17/(-m**2 - 0.17*I*sqrt(m**2 - 0.07)/Abs(m) + 2.27)" assert a2 == "0.06/(-m**2 - 0.06*I*sqrt(m**2 - 0.07)/Abs(m) + 0.98)" - else: - raise NotImplementedError def round_nested(expression: sp.Expr, n_decimals: int) -> sp.Expr: diff --git a/tests/test_helicity.py b/tests/test_helicity.py index 5bf35ed61..469409436 100644 --- a/tests/test_helicity.py +++ b/tests/test_helicity.py @@ -1,4 +1,3 @@ -import pytest import sympy as sp from qrules import ReactionInfo from sympy import cos, sin, sqrt @@ -6,26 +5,14 @@ from ampform import get_builder -@pytest.mark.parametrize( - ("formalism", "n_amplitudes", "n_parameters"), - [ - ("canonical", 16, 4), - ("helicity", 8, 2), - ], -) -def test_generate( - formalism: str, - n_amplitudes: int, - n_parameters: int, - jpsi_to_gamma_pi_pi_canonical_solutions: ReactionInfo, - jpsi_to_gamma_pi_pi_helicity_solutions: ReactionInfo, -): - if formalism == "canonical": - reaction = jpsi_to_gamma_pi_pi_canonical_solutions - elif formalism == "helicity": - reaction = jpsi_to_gamma_pi_pi_helicity_solutions +def test_generate(reaction: ReactionInfo): + if reaction.formalism == "canonical-helicity": + n_amplitudes = 16 + n_parameters = 4 else: - raise NotImplementedError + n_amplitudes = 8 + n_parameters = 2 + model = get_builder(reaction).generate() assert len(model.parameter_defaults) == n_parameters assert len(model.components) == 4 + n_amplitudes @@ -39,7 +26,8 @@ def test_generate( theta = sp.Symbol("theta", real=True) no_dynamics = no_dynamics.subs({existing_theta: theta}) no_dynamics = no_dynamics.trigsimp() - if formalism == "canonical": + + if reaction.formalism == "canonical-helicity": assert ( no_dynamics == 0.8 * sqrt(10) * cos(theta) ** 2 @@ -47,7 +35,5 @@ def test_generate( + 0.8 * sqrt(10) + 4.4 ) - elif formalism == "helicity": - assert no_dynamics == 8.0 - 4.0 * sin(theta) ** 2 else: - raise NotImplementedError + assert no_dynamics == 8.0 - 4.0 * sin(theta) ** 2 diff --git a/tests/test_parity_prefactor.py b/tests/test_parity_prefactor.py index e77e1cb78..a97b07519 100644 --- a/tests/test_parity_prefactor.py +++ b/tests/test_parity_prefactor.py @@ -1,7 +1,7 @@ from typing import NamedTuple import pytest -import qrules +from qrules import StateTransitionManager from ampform import get_builder @@ -14,44 +14,41 @@ class Input(NamedTuple): @pytest.mark.parametrize( - ("test_input", "parameter_count"), + ("test_input", "n_parameters"), [ ( Input( - [("Lambda(c)+", [0.5])], - ["p", "K-", "pi+"], - ["Lambda(1405)"], - [], + initial_state=[("Lambda(c)+", [0.5])], + final_state=["p", "K-", "pi+"], + intermediate_states=["Lambda(1405)"], + final_state_grouping=[], ), 2, ), ( Input( - [("Lambda(c)+", [0.5])], - ["p", "K-", "pi+"], - ["Delta(1232)++"], - [], + initial_state=[("Lambda(c)+", [0.5])], + final_state=["p", "K-", "pi+"], + intermediate_states=["Delta(1232)++"], + final_state_grouping=[], ), 2, ), ( Input( - [("Lambda(c)+", [0.5])], - ["p", "K-", "pi+"], - ["K*(892)0"], - [], + initial_state=[("Lambda(c)+", [0.5])], + final_state=["p", "K-", "pi+"], + intermediate_states=["K*(892)0"], + final_state_grouping=[], ), 4, ), ], ) -def test_parity_amplitude_coupling( - test_input: Input, - parameter_count: int, -) -> None: - stm = qrules.StateTransitionManager( - test_input.initial_state, - test_input.final_state, +def test_parity_amplitude_coupling(test_input: Input, n_parameters: int): + stm = StateTransitionManager( + initial_state=test_input.initial_state, + final_state=test_input.final_state, allowed_intermediate_particles=test_input.intermediate_states, number_of_threads=1, ) @@ -60,4 +57,4 @@ def test_parity_amplitude_coupling( model_builder = get_builder(reaction) amplitude_model = model_builder.generate() - assert len(amplitude_model.parameter_defaults) == parameter_count + assert len(amplitude_model.parameter_defaults) == n_parameters